mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-11 18:44:45 +08:00
merge upstream main
This commit is contained in:
@@ -47,6 +47,7 @@ type Config struct {
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Ops OpsConfig `mapstructure:"ops"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Totp TotpConfig `mapstructure:"totp"`
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
@@ -55,6 +56,7 @@ type Config struct {
|
||||
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
|
||||
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
|
||||
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
|
||||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
@@ -267,6 +269,33 @@ type GatewayConfig struct {
|
||||
|
||||
// Scheduling: 账号调度相关配置
|
||||
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
|
||||
|
||||
// TLSFingerprint: TLS指纹伪装配置
|
||||
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
|
||||
}
|
||||
|
||||
// TLSFingerprintConfig TLS指纹伪装配置
|
||||
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
|
||||
type TLSFingerprintConfig struct {
|
||||
// Enabled: 是否全局启用TLS指纹功能
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// Profiles: 预定义的TLS指纹配置模板
|
||||
// key 为模板名称,如 "claude_cli_v2", "chrome_120" 等
|
||||
Profiles map[string]TLSProfileConfig `mapstructure:"profiles"`
|
||||
}
|
||||
|
||||
// TLSProfileConfig 单个TLS指纹模板的配置
|
||||
type TLSProfileConfig struct {
|
||||
// Name: 模板显示名称
|
||||
Name string `mapstructure:"name"`
|
||||
// EnableGREASE: 是否启用GREASE扩展(Chrome使用,Node.js不使用)
|
||||
EnableGREASE bool `mapstructure:"enable_grease"`
|
||||
// CipherSuites: TLS加密套件列表(空则使用内置默认值)
|
||||
CipherSuites []uint16 `mapstructure:"cipher_suites"`
|
||||
// Curves: 椭圆曲线列表(空则使用内置默认值)
|
||||
Curves []uint16 `mapstructure:"curves"`
|
||||
// PointFormats: 点格式列表(空则使用内置默认值)
|
||||
PointFormats []uint8 `mapstructure:"point_formats"`
|
||||
}
|
||||
|
||||
// GatewaySchedulingConfig accounts scheduling configuration.
|
||||
@@ -386,6 +415,8 @@ type RedisConfig struct {
|
||||
PoolSize int `mapstructure:"pool_size"`
|
||||
// MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟
|
||||
MinIdleConns int `mapstructure:"min_idle_conns"`
|
||||
// EnableTLS: 是否启用 TLS/SSL 连接
|
||||
EnableTLS bool `mapstructure:"enable_tls"`
|
||||
}
|
||||
|
||||
func (r *RedisConfig) Address() string {
|
||||
@@ -438,6 +469,16 @@ type JWTConfig struct {
|
||||
ExpireHour int `mapstructure:"expire_hour"`
|
||||
}
|
||||
|
||||
// TotpConfig TOTP 双因素认证配置
|
||||
type TotpConfig struct {
|
||||
// EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥(32 字节 hex 编码)
|
||||
// 如果为空,将自动生成一个随机密钥(仅适用于开发环境)
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
// EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成)
|
||||
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
|
||||
EncryptionKeyConfigured bool `mapstructure:"-"`
|
||||
}
|
||||
|
||||
type TurnstileConfig struct {
|
||||
Required bool `mapstructure:"required"`
|
||||
}
|
||||
@@ -504,6 +545,20 @@ type DashboardAggregationRetentionConfig struct {
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
}
|
||||
|
||||
// UsageCleanupConfig 使用记录清理任务配置
|
||||
type UsageCleanupConfig struct {
|
||||
// Enabled: 是否启用清理任务执行器
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// MaxRangeDays: 单次任务允许的最大时间跨度(天)
|
||||
MaxRangeDays int `mapstructure:"max_range_days"`
|
||||
// BatchSize: 单批删除数量
|
||||
BatchSize int `mapstructure:"batch_size"`
|
||||
// WorkerIntervalSeconds: 后台任务轮询间隔(秒)
|
||||
WorkerIntervalSeconds int `mapstructure:"worker_interval_seconds"`
|
||||
// TaskTimeoutSeconds: 单次任务最大执行时长(秒)
|
||||
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
|
||||
}
|
||||
|
||||
func NormalizeRunMode(value string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(value))
|
||||
switch normalized {
|
||||
@@ -584,6 +639,20 @@ func Load() (*Config, error) {
|
||||
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
||||
}
|
||||
|
||||
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
|
||||
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
||||
if cfg.Totp.EncryptionKey == "" {
|
||||
key, err := generateJWTSecret(32) // Reuse the same random generation function
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate totp encryption key error: %w", err)
|
||||
}
|
||||
cfg.Totp.EncryptionKey = key
|
||||
cfg.Totp.EncryptionKeyConfigured = false
|
||||
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
|
||||
} else {
|
||||
cfg.Totp.EncryptionKeyConfigured = true
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("validate config error: %w", err)
|
||||
}
|
||||
@@ -695,6 +764,7 @@ func setDefaults() {
|
||||
viper.SetDefault("redis.write_timeout_seconds", 3)
|
||||
viper.SetDefault("redis.pool_size", 128)
|
||||
viper.SetDefault("redis.min_idle_conns", 10)
|
||||
viper.SetDefault("redis.enable_tls", false)
|
||||
|
||||
// Ops (vNext)
|
||||
viper.SetDefault("ops.enabled", true)
|
||||
@@ -714,6 +784,9 @@ func setDefaults() {
|
||||
viper.SetDefault("jwt.secret", "")
|
||||
viper.SetDefault("jwt.expire_hour", 24)
|
||||
|
||||
// TOTP
|
||||
viper.SetDefault("totp.encryption_key", "")
|
||||
|
||||
// Default
|
||||
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
|
||||
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.
|
||||
@@ -764,6 +837,13 @@ func setDefaults() {
|
||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||
|
||||
// Usage cleanup task
|
||||
viper.SetDefault("usage_cleanup.enabled", true)
|
||||
viper.SetDefault("usage_cleanup.max_range_days", 31)
|
||||
viper.SetDefault("usage_cleanup.batch_size", 5000)
|
||||
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
|
||||
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
|
||||
|
||||
// Gateway
|
||||
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||
viper.SetDefault("gateway.log_upstream_error_body", true)
|
||||
@@ -802,6 +882,8 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
|
||||
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
|
||||
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
|
||||
// TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
|
||||
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
||||
viper.SetDefault("concurrency.ping_interval", 10)
|
||||
|
||||
// TokenRefresh
|
||||
@@ -1004,6 +1086,33 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
|
||||
}
|
||||
}
|
||||
if c.UsageCleanup.Enabled {
|
||||
if c.UsageCleanup.MaxRangeDays <= 0 {
|
||||
return fmt.Errorf("usage_cleanup.max_range_days must be positive")
|
||||
}
|
||||
if c.UsageCleanup.BatchSize <= 0 {
|
||||
return fmt.Errorf("usage_cleanup.batch_size must be positive")
|
||||
}
|
||||
if c.UsageCleanup.WorkerIntervalSeconds <= 0 {
|
||||
return fmt.Errorf("usage_cleanup.worker_interval_seconds must be positive")
|
||||
}
|
||||
if c.UsageCleanup.TaskTimeoutSeconds <= 0 {
|
||||
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be positive")
|
||||
}
|
||||
} else {
|
||||
if c.UsageCleanup.MaxRangeDays < 0 {
|
||||
return fmt.Errorf("usage_cleanup.max_range_days must be non-negative")
|
||||
}
|
||||
if c.UsageCleanup.BatchSize < 0 {
|
||||
return fmt.Errorf("usage_cleanup.batch_size must be non-negative")
|
||||
}
|
||||
if c.UsageCleanup.WorkerIntervalSeconds < 0 {
|
||||
return fmt.Errorf("usage_cleanup.worker_interval_seconds must be non-negative")
|
||||
}
|
||||
if c.UsageCleanup.TaskTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
|
||||
}
|
||||
}
|
||||
if c.Gateway.MaxBodySize <= 0 {
|
||||
return fmt.Errorf("gateway.max_body_size must be positive")
|
||||
}
|
||||
|
||||
@@ -280,3 +280,573 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
|
||||
t.Fatalf("Validate() expected backfill_max_days error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if !cfg.UsageCleanup.Enabled {
|
||||
t.Fatalf("UsageCleanup.Enabled = false, want true")
|
||||
}
|
||||
if cfg.UsageCleanup.MaxRangeDays != 31 {
|
||||
t.Fatalf("UsageCleanup.MaxRangeDays = %d, want 31", cfg.UsageCleanup.MaxRangeDays)
|
||||
}
|
||||
if cfg.UsageCleanup.BatchSize != 5000 {
|
||||
t.Fatalf("UsageCleanup.BatchSize = %d, want 5000", cfg.UsageCleanup.BatchSize)
|
||||
}
|
||||
if cfg.UsageCleanup.WorkerIntervalSeconds != 10 {
|
||||
t.Fatalf("UsageCleanup.WorkerIntervalSeconds = %d, want 10", cfg.UsageCleanup.WorkerIntervalSeconds)
|
||||
}
|
||||
if cfg.UsageCleanup.TaskTimeoutSeconds != 1800 {
|
||||
t.Fatalf("UsageCleanup.TaskTimeoutSeconds = %d, want 1800", cfg.UsageCleanup.TaskTimeoutSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.UsageCleanup.Enabled = true
|
||||
cfg.UsageCleanup.MaxRangeDays = 0
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error for usage_cleanup.max_range_days, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "usage_cleanup.max_range_days") {
|
||||
t.Fatalf("Validate() expected max_range_days error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.UsageCleanup.Enabled = false
|
||||
cfg.UsageCleanup.BatchSize = -1
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error for usage_cleanup.batch_size, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "usage_cleanup.batch_size") {
|
||||
t.Fatalf("Validate() expected batch_size error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigAddressHelpers(t *testing.T) {
|
||||
server := ServerConfig{Host: "127.0.0.1", Port: 9000}
|
||||
if server.Address() != "127.0.0.1:9000" {
|
||||
t.Fatalf("ServerConfig.Address() = %q", server.Address())
|
||||
}
|
||||
|
||||
dbCfg := DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "",
|
||||
DBName: "sub2api",
|
||||
SSLMode: "disable",
|
||||
}
|
||||
if !strings.Contains(dbCfg.DSN(), "password=") {
|
||||
} else {
|
||||
t.Fatalf("DatabaseConfig.DSN() should not include password when empty")
|
||||
}
|
||||
|
||||
dbCfg.Password = "secret"
|
||||
if !strings.Contains(dbCfg.DSN(), "password=secret") {
|
||||
t.Fatalf("DatabaseConfig.DSN() missing password")
|
||||
}
|
||||
|
||||
dbCfg.Password = ""
|
||||
if strings.Contains(dbCfg.DSNWithTimezone("UTC"), "password=") {
|
||||
t.Fatalf("DatabaseConfig.DSNWithTimezone() should omit password when empty")
|
||||
}
|
||||
|
||||
if !strings.Contains(dbCfg.DSNWithTimezone(""), "TimeZone=Asia/Shanghai") {
|
||||
t.Fatalf("DatabaseConfig.DSNWithTimezone() should use default timezone")
|
||||
}
|
||||
if !strings.Contains(dbCfg.DSNWithTimezone("UTC"), "TimeZone=UTC") {
|
||||
t.Fatalf("DatabaseConfig.DSNWithTimezone() should use provided timezone")
|
||||
}
|
||||
|
||||
redis := RedisConfig{Host: "redis", Port: 6379}
|
||||
if redis.Address() != "redis:6379" {
|
||||
t.Fatalf("RedisConfig.Address() = %q", redis.Address())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStringSlice(t *testing.T) {
|
||||
values := normalizeStringSlice([]string{" a ", "", "b", " ", "c"})
|
||||
if len(values) != 3 || values[0] != "a" || values[1] != "b" || values[2] != "c" {
|
||||
t.Fatalf("normalizeStringSlice() unexpected result: %#v", values)
|
||||
}
|
||||
if normalizeStringSlice(nil) != nil {
|
||||
t.Fatalf("normalizeStringSlice(nil) expected nil slice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetServerAddressFromEnv(t *testing.T) {
|
||||
t.Setenv("SERVER_HOST", "127.0.0.1")
|
||||
t.Setenv("SERVER_PORT", "9090")
|
||||
|
||||
address := GetServerAddress()
|
||||
if address != "127.0.0.1:9090" {
|
||||
t.Fatalf("GetServerAddress() = %q", address)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAbsoluteHTTPURL(t *testing.T) {
|
||||
if err := ValidateAbsoluteHTTPURL("https://example.com/path"); err != nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL valid url error: %v", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(""); err == nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL should reject empty url")
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL("/relative"); err == nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL should reject relative url")
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL("ftp://example.com"); err == nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL should reject ftp scheme")
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL("https://example.com/#frag"); err == nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL should reject fragment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFrontendRedirectURL(t *testing.T) {
|
||||
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL("https://example.com/auth"); err != nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL absolute error: %v", err)
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL("example.com/path"); err == nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL should reject non-absolute url")
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL("//evil.com"); err == nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL should reject // prefix")
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL("javascript:alert(1)"); err == nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL should reject javascript scheme")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWarnIfInsecureURL(t *testing.T) {
|
||||
warnIfInsecureURL("test", "http://example.com")
|
||||
warnIfInsecureURL("test", "bad://url")
|
||||
}
|
||||
|
||||
func TestGenerateJWTSecretDefaultLength(t *testing.T) {
|
||||
secret, err := generateJWTSecret(0)
|
||||
if err != nil {
|
||||
t.Fatalf("generateJWTSecret error: %v", err)
|
||||
}
|
||||
if len(secret) == 0 {
|
||||
t.Fatalf("generateJWTSecret returned empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
cfg.Ops.Cleanup.Enabled = true
|
||||
cfg.Ops.Cleanup.Schedule = ""
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error for ops.cleanup.schedule")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "ops.cleanup.schedule") {
|
||||
t.Fatalf("Validate() expected ops.cleanup.schedule error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConcurrencyPingInterval(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
cfg.Concurrency.PingInterval = 3
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error for concurrency.ping_interval")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "concurrency.ping_interval") {
|
||||
t.Fatalf("Validate() expected concurrency.ping_interval error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvideConfig(t *testing.T) {
|
||||
viper.Reset()
|
||||
if _, err := ProvideConfig(); err != nil {
|
||||
t.Fatalf("ProvideConfig() error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Security.CSP.Enabled = true
|
||||
cfg.Security.CSP.Policy = "default-src 'self'"
|
||||
|
||||
cfg.LinuxDo.Enabled = true
|
||||
cfg.LinuxDo.ClientID = "client"
|
||||
cfg.LinuxDo.ClientSecret = "secret"
|
||||
cfg.LinuxDo.AuthorizeURL = "https://example.com/oauth2/authorize"
|
||||
cfg.LinuxDo.TokenURL = "https://example.com/oauth2/token"
|
||||
cfg.LinuxDo.UserInfoURL = "https://example.com/oauth2/userinfo"
|
||||
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
|
||||
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
|
||||
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate() unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTSecretStrength(t *testing.T) {
|
||||
if !isWeakJWTSecret("change-me-in-production") {
|
||||
t.Fatalf("isWeakJWTSecret should detect weak secret")
|
||||
}
|
||||
if isWeakJWTSecret("StrongSecretValue") {
|
||||
t.Fatalf("isWeakJWTSecret should accept strong secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateJWTSecretWithLength(t *testing.T) {
|
||||
secret, err := generateJWTSecret(16)
|
||||
if err != nil {
|
||||
t.Fatalf("generateJWTSecret error: %v", err)
|
||||
}
|
||||
if len(secret) == 0 {
|
||||
t.Fatalf("generateJWTSecret returned empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
|
||||
if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFrontendRedirectURLInvalidChars(t *testing.T) {
|
||||
if err := ValidateFrontendRedirectURL("/auth/\ncallback"); err == nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL should reject invalid chars")
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL("http://"); err == nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL should reject missing host")
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL("mailto:user@example.com"); err == nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL should reject mailto")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWarnIfInsecureURLHTTPS(t *testing.T) {
|
||||
warnIfInsecureURL("secure", "https://example.com")
|
||||
}
|
||||
|
||||
func TestValidateConfigErrors(t *testing.T) {
|
||||
buildValid := func(t *testing.T) *Config {
|
||||
t.Helper()
|
||||
viper.Reset()
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
mutate func(*Config)
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "jwt expire hour positive",
|
||||
mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
|
||||
wantErr: "jwt.expire_hour must be positive",
|
||||
},
|
||||
{
|
||||
name: "jwt expire hour max",
|
||||
mutate: func(c *Config) { c.JWT.ExpireHour = 200 },
|
||||
wantErr: "jwt.expire_hour must be <= 168",
|
||||
},
|
||||
{
|
||||
name: "csp policy required",
|
||||
mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" },
|
||||
wantErr: "security.csp.policy",
|
||||
},
|
||||
{
|
||||
name: "linuxdo client id required",
|
||||
mutate: func(c *Config) {
|
||||
c.LinuxDo.Enabled = true
|
||||
c.LinuxDo.ClientID = ""
|
||||
},
|
||||
wantErr: "linuxdo_connect.client_id",
|
||||
},
|
||||
{
|
||||
name: "linuxdo token auth method",
|
||||
mutate: func(c *Config) {
|
||||
c.LinuxDo.Enabled = true
|
||||
c.LinuxDo.ClientID = "client"
|
||||
c.LinuxDo.ClientSecret = "secret"
|
||||
c.LinuxDo.AuthorizeURL = "https://example.com/authorize"
|
||||
c.LinuxDo.TokenURL = "https://example.com/token"
|
||||
c.LinuxDo.UserInfoURL = "https://example.com/userinfo"
|
||||
c.LinuxDo.RedirectURL = "https://example.com/callback"
|
||||
c.LinuxDo.FrontendRedirectURL = "/auth/callback"
|
||||
c.LinuxDo.TokenAuthMethod = "invalid"
|
||||
},
|
||||
wantErr: "linuxdo_connect.token_auth_method",
|
||||
},
|
||||
{
|
||||
name: "billing circuit breaker threshold",
|
||||
mutate: func(c *Config) { c.Billing.CircuitBreaker.FailureThreshold = 0 },
|
||||
wantErr: "billing.circuit_breaker.failure_threshold",
|
||||
},
|
||||
{
|
||||
name: "billing circuit breaker reset",
|
||||
mutate: func(c *Config) { c.Billing.CircuitBreaker.ResetTimeoutSeconds = 0 },
|
||||
wantErr: "billing.circuit_breaker.reset_timeout_seconds",
|
||||
},
|
||||
{
|
||||
name: "billing circuit breaker half open",
|
||||
mutate: func(c *Config) { c.Billing.CircuitBreaker.HalfOpenRequests = 0 },
|
||||
wantErr: "billing.circuit_breaker.half_open_requests",
|
||||
},
|
||||
{
|
||||
name: "database max open conns",
|
||||
mutate: func(c *Config) { c.Database.MaxOpenConns = 0 },
|
||||
wantErr: "database.max_open_conns",
|
||||
},
|
||||
{
|
||||
name: "database max lifetime",
|
||||
mutate: func(c *Config) { c.Database.ConnMaxLifetimeMinutes = -1 },
|
||||
wantErr: "database.conn_max_lifetime_minutes",
|
||||
},
|
||||
{
|
||||
name: "database idle exceeds open",
|
||||
mutate: func(c *Config) { c.Database.MaxIdleConns = c.Database.MaxOpenConns + 1 },
|
||||
wantErr: "database.max_idle_conns cannot exceed",
|
||||
},
|
||||
{
|
||||
name: "redis dial timeout",
|
||||
mutate: func(c *Config) { c.Redis.DialTimeoutSeconds = 0 },
|
||||
wantErr: "redis.dial_timeout_seconds",
|
||||
},
|
||||
{
|
||||
name: "redis read timeout",
|
||||
mutate: func(c *Config) { c.Redis.ReadTimeoutSeconds = 0 },
|
||||
wantErr: "redis.read_timeout_seconds",
|
||||
},
|
||||
{
|
||||
name: "redis write timeout",
|
||||
mutate: func(c *Config) { c.Redis.WriteTimeoutSeconds = 0 },
|
||||
wantErr: "redis.write_timeout_seconds",
|
||||
},
|
||||
{
|
||||
name: "redis pool size",
|
||||
mutate: func(c *Config) { c.Redis.PoolSize = 0 },
|
||||
wantErr: "redis.pool_size",
|
||||
},
|
||||
{
|
||||
name: "redis idle exceeds pool",
|
||||
mutate: func(c *Config) { c.Redis.MinIdleConns = c.Redis.PoolSize + 1 },
|
||||
wantErr: "redis.min_idle_conns cannot exceed",
|
||||
},
|
||||
{
|
||||
name: "dashboard cache disabled negative",
|
||||
mutate: func(c *Config) { c.Dashboard.Enabled = false; c.Dashboard.StatsTTLSeconds = -1 },
|
||||
wantErr: "dashboard_cache.stats_ttl_seconds",
|
||||
},
|
||||
{
|
||||
name: "dashboard cache fresh ttl positive",
|
||||
mutate: func(c *Config) { c.Dashboard.Enabled = true; c.Dashboard.StatsFreshTTLSeconds = 0 },
|
||||
wantErr: "dashboard_cache.stats_fresh_ttl_seconds",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation enabled interval",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.IntervalSeconds = 0 },
|
||||
wantErr: "dashboard_aggregation.interval_seconds",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation backfill positive",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.BackfillEnabled = true
|
||||
c.DashboardAgg.BackfillMaxDays = 0
|
||||
},
|
||||
wantErr: "dashboard_aggregation.backfill_max_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation retention",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation disabled interval",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||
wantErr: "dashboard_aggregation.interval_seconds",
|
||||
},
|
||||
{
|
||||
name: "usage cleanup max range",
|
||||
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.MaxRangeDays = 0 },
|
||||
wantErr: "usage_cleanup.max_range_days",
|
||||
},
|
||||
{
|
||||
name: "usage cleanup worker interval",
|
||||
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.WorkerIntervalSeconds = 0 },
|
||||
wantErr: "usage_cleanup.worker_interval_seconds",
|
||||
},
|
||||
{
|
||||
name: "usage cleanup batch size",
|
||||
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.BatchSize = 0 },
|
||||
wantErr: "usage_cleanup.batch_size",
|
||||
},
|
||||
{
|
||||
name: "usage cleanup disabled negative",
|
||||
mutate: func(c *Config) { c.UsageCleanup.Enabled = false; c.UsageCleanup.BatchSize = -1 },
|
||||
wantErr: "usage_cleanup.batch_size",
|
||||
},
|
||||
{
|
||||
name: "gateway max body size",
|
||||
mutate: func(c *Config) { c.Gateway.MaxBodySize = 0 },
|
||||
wantErr: "gateway.max_body_size",
|
||||
},
|
||||
{
|
||||
name: "gateway max idle conns",
|
||||
mutate: func(c *Config) { c.Gateway.MaxIdleConns = 0 },
|
||||
wantErr: "gateway.max_idle_conns",
|
||||
},
|
||||
{
|
||||
name: "gateway max idle conns per host",
|
||||
mutate: func(c *Config) { c.Gateway.MaxIdleConnsPerHost = 0 },
|
||||
wantErr: "gateway.max_idle_conns_per_host",
|
||||
},
|
||||
{
|
||||
name: "gateway idle timeout",
|
||||
mutate: func(c *Config) { c.Gateway.IdleConnTimeoutSeconds = 0 },
|
||||
wantErr: "gateway.idle_conn_timeout_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway max upstream clients",
|
||||
mutate: func(c *Config) { c.Gateway.MaxUpstreamClients = 0 },
|
||||
wantErr: "gateway.max_upstream_clients",
|
||||
},
|
||||
{
|
||||
name: "gateway client idle ttl",
|
||||
mutate: func(c *Config) { c.Gateway.ClientIdleTTLSeconds = 0 },
|
||||
wantErr: "gateway.client_idle_ttl_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway concurrency slot ttl",
|
||||
mutate: func(c *Config) { c.Gateway.ConcurrencySlotTTLMinutes = 0 },
|
||||
wantErr: "gateway.concurrency_slot_ttl_minutes",
|
||||
},
|
||||
{
|
||||
name: "gateway max conns per host",
|
||||
mutate: func(c *Config) { c.Gateway.MaxConnsPerHost = -1 },
|
||||
wantErr: "gateway.max_conns_per_host",
|
||||
},
|
||||
{
|
||||
name: "gateway connection isolation",
|
||||
mutate: func(c *Config) { c.Gateway.ConnectionPoolIsolation = "invalid" },
|
||||
wantErr: "gateway.connection_pool_isolation",
|
||||
},
|
||||
{
|
||||
name: "gateway stream keepalive range",
|
||||
mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
|
||||
wantErr: "gateway.stream_keepalive_interval",
|
||||
},
|
||||
{
|
||||
name: "gateway stream data interval range",
|
||||
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
|
||||
wantErr: "gateway.stream_data_interval_timeout",
|
||||
},
|
||||
{
|
||||
name: "gateway stream data interval negative",
|
||||
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 },
|
||||
wantErr: "gateway.stream_data_interval_timeout must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway max line size",
|
||||
mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 },
|
||||
wantErr: "gateway.max_line_size must be at least",
|
||||
},
|
||||
{
|
||||
name: "gateway max line size negative",
|
||||
mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
|
||||
wantErr: "gateway.max_line_size must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway scheduling sticky waiting",
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
|
||||
wantErr: "gateway.scheduling.sticky_session_max_waiting",
|
||||
},
|
||||
{
|
||||
name: "gateway scheduling outbox poll",
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },
|
||||
wantErr: "gateway.scheduling.outbox_poll_interval_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway scheduling outbox failures",
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxLagRebuildFailures = 0 },
|
||||
wantErr: "gateway.scheduling.outbox_lag_rebuild_failures",
|
||||
},
|
||||
{
|
||||
name: "gateway outbox lag rebuild",
|
||||
mutate: func(c *Config) {
|
||||
c.Gateway.Scheduling.OutboxLagWarnSeconds = 10
|
||||
c.Gateway.Scheduling.OutboxLagRebuildSeconds = 5
|
||||
},
|
||||
wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
|
||||
},
|
||||
{
|
||||
name: "ops metrics collector ttl",
|
||||
mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
|
||||
wantErr: "ops.metrics_collector_cache.ttl",
|
||||
},
|
||||
{
|
||||
name: "ops cleanup retention",
|
||||
mutate: func(c *Config) { c.Ops.Cleanup.ErrorLogRetentionDays = -1 },
|
||||
wantErr: "ops.cleanup.error_log_retention_days",
|
||||
},
|
||||
{
|
||||
name: "ops cleanup minute retention",
|
||||
mutate: func(c *Config) { c.Ops.Cleanup.MinuteMetricsRetentionDays = -1 },
|
||||
wantErr: "ops.cleanup.minute_metrics_retention_days",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := buildValid(t)
|
||||
tt.mutate(cfg)
|
||||
err := cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
226
backend/internal/domain/announcement.go
Normal file
226
backend/internal/domain/announcement.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementStatusDraft = "draft"
|
||||
AnnouncementStatusActive = "active"
|
||||
AnnouncementStatusArchived = "archived"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementConditionTypeSubscription = "subscription"
|
||||
AnnouncementConditionTypeBalance = "balance"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementOperatorIn = "in"
|
||||
AnnouncementOperatorGT = "gt"
|
||||
AnnouncementOperatorGTE = "gte"
|
||||
AnnouncementOperatorLT = "lt"
|
||||
AnnouncementOperatorLTE = "lte"
|
||||
AnnouncementOperatorEQ = "eq"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAnnouncementNotFound = infraerrors.NotFound("ANNOUNCEMENT_NOT_FOUND", "announcement not found")
|
||||
ErrAnnouncementInvalidTarget = infraerrors.BadRequest("ANNOUNCEMENT_INVALID_TARGET", "invalid announcement targeting rules")
|
||||
)
|
||||
|
||||
type AnnouncementTargeting struct {
|
||||
// AnyOf 表示 OR:任意一个条件组满足即可展示。
|
||||
AnyOf []AnnouncementConditionGroup `json:"any_of,omitempty"`
|
||||
}
|
||||
|
||||
type AnnouncementConditionGroup struct {
|
||||
// AllOf 表示 AND:组内所有条件都满足才算命中该组。
|
||||
AllOf []AnnouncementCondition `json:"all_of,omitempty"`
|
||||
}
|
||||
|
||||
type AnnouncementCondition struct {
|
||||
// Type: subscription | balance
|
||||
Type string `json:"type"`
|
||||
|
||||
// Operator:
|
||||
// - subscription: in
|
||||
// - balance: gt/gte/lt/lte/eq
|
||||
Operator string `json:"operator"`
|
||||
|
||||
// subscription 条件:匹配的订阅套餐(group_id)
|
||||
GroupIDs []int64 `json:"group_ids,omitempty"`
|
||||
|
||||
// balance 条件:比较阈值
|
||||
Value float64 `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
func (t AnnouncementTargeting) Matches(balance float64, activeSubscriptionGroupIDs map[int64]struct{}) bool {
|
||||
// 空规则:展示给所有用户
|
||||
if len(t.AnyOf) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, group := range t.AnyOf {
|
||||
if len(group.AllOf) == 0 {
|
||||
// 空条件组不命中(避免 OR 中出现无条件 “全命中”)
|
||||
continue
|
||||
}
|
||||
allMatched := true
|
||||
for _, cond := range group.AllOf {
|
||||
if !cond.Matches(balance, activeSubscriptionGroupIDs) {
|
||||
allMatched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allMatched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c AnnouncementCondition) Matches(balance float64, activeSubscriptionGroupIDs map[int64]struct{}) bool {
|
||||
switch c.Type {
|
||||
case AnnouncementConditionTypeSubscription:
|
||||
if c.Operator != AnnouncementOperatorIn {
|
||||
return false
|
||||
}
|
||||
if len(c.GroupIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
if len(activeSubscriptionGroupIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, gid := range c.GroupIDs {
|
||||
if _, ok := activeSubscriptionGroupIDs[gid]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
case AnnouncementConditionTypeBalance:
|
||||
switch c.Operator {
|
||||
case AnnouncementOperatorGT:
|
||||
return balance > c.Value
|
||||
case AnnouncementOperatorGTE:
|
||||
return balance >= c.Value
|
||||
case AnnouncementOperatorLT:
|
||||
return balance < c.Value
|
||||
case AnnouncementOperatorLTE:
|
||||
return balance <= c.Value
|
||||
case AnnouncementOperatorEQ:
|
||||
return balance == c.Value
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (t AnnouncementTargeting) NormalizeAndValidate() (AnnouncementTargeting, error) {
|
||||
normalized := AnnouncementTargeting{AnyOf: make([]AnnouncementConditionGroup, 0, len(t.AnyOf))}
|
||||
|
||||
// 允许空 targeting(展示给所有用户)
|
||||
if len(t.AnyOf) == 0 {
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
if len(t.AnyOf) > 50 {
|
||||
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
|
||||
}
|
||||
|
||||
for _, g := range t.AnyOf {
|
||||
if len(g.AllOf) == 0 {
|
||||
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
|
||||
}
|
||||
if len(g.AllOf) > 50 {
|
||||
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
|
||||
}
|
||||
|
||||
group := AnnouncementConditionGroup{AllOf: make([]AnnouncementCondition, 0, len(g.AllOf))}
|
||||
for _, c := range g.AllOf {
|
||||
cond := AnnouncementCondition{
|
||||
Type: strings.TrimSpace(c.Type),
|
||||
Operator: strings.TrimSpace(c.Operator),
|
||||
Value: c.Value,
|
||||
}
|
||||
for _, gid := range c.GroupIDs {
|
||||
if gid <= 0 {
|
||||
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
|
||||
}
|
||||
cond.GroupIDs = append(cond.GroupIDs, gid)
|
||||
}
|
||||
|
||||
if err := cond.validate(); err != nil {
|
||||
return AnnouncementTargeting{}, err
|
||||
}
|
||||
group.AllOf = append(group.AllOf, cond)
|
||||
}
|
||||
|
||||
normalized.AnyOf = append(normalized.AnyOf, group)
|
||||
}
|
||||
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func (c AnnouncementCondition) validate() error {
|
||||
switch c.Type {
|
||||
case AnnouncementConditionTypeSubscription:
|
||||
if c.Operator != AnnouncementOperatorIn {
|
||||
return ErrAnnouncementInvalidTarget
|
||||
}
|
||||
if len(c.GroupIDs) == 0 {
|
||||
return ErrAnnouncementInvalidTarget
|
||||
}
|
||||
return nil
|
||||
|
||||
case AnnouncementConditionTypeBalance:
|
||||
switch c.Operator {
|
||||
case AnnouncementOperatorGT, AnnouncementOperatorGTE, AnnouncementOperatorLT, AnnouncementOperatorLTE, AnnouncementOperatorEQ:
|
||||
return nil
|
||||
default:
|
||||
return ErrAnnouncementInvalidTarget
|
||||
}
|
||||
|
||||
default:
|
||||
return ErrAnnouncementInvalidTarget
|
||||
}
|
||||
}
|
||||
|
||||
type Announcement struct {
|
||||
ID int64
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
CreatedBy *int64
|
||||
UpdatedBy *int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (a *Announcement) IsActiveAt(now time.Time) bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
if a.Status != AnnouncementStatusActive {
|
||||
return false
|
||||
}
|
||||
if a.StartsAt != nil && now.Before(*a.StartsAt) {
|
||||
return false
|
||||
}
|
||||
if a.EndsAt != nil && !now.Before(*a.EndsAt) {
|
||||
// ends_at 语义:到点即下线
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
64
backend/internal/domain/constants.go
Normal file
64
backend/internal/domain/constants.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package domain
|
||||
|
||||
// Status constants
|
||||
const (
|
||||
StatusActive = "active"
|
||||
StatusDisabled = "disabled"
|
||||
StatusError = "error"
|
||||
StatusUnused = "unused"
|
||||
StatusUsed = "used"
|
||||
StatusExpired = "expired"
|
||||
)
|
||||
|
||||
// Role constants
|
||||
const (
|
||||
RoleAdmin = "admin"
|
||||
RoleUser = "user"
|
||||
)
|
||||
|
||||
// Platform constants
|
||||
const (
|
||||
PlatformAnthropic = "anthropic"
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
const (
|
||||
RedeemTypeBalance = "balance"
|
||||
RedeemTypeConcurrency = "concurrency"
|
||||
RedeemTypeSubscription = "subscription"
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
const (
|
||||
PromoCodeStatusActive = "active"
|
||||
PromoCodeStatusDisabled = "disabled"
|
||||
)
|
||||
|
||||
// Admin adjustment type constants
|
||||
const (
|
||||
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
|
||||
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
|
||||
)
|
||||
|
||||
// Group subscription type constants
|
||||
const (
|
||||
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
|
||||
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
|
||||
)
|
||||
|
||||
// Subscription status constants
|
||||
const (
|
||||
SubscriptionStatusActive = "active"
|
||||
SubscriptionStatusExpired = "expired"
|
||||
SubscriptionStatusSuspended = "suspended"
|
||||
)
|
||||
@@ -45,6 +45,7 @@ type AccountHandler struct {
|
||||
concurrencyService *service.ConcurrencyService
|
||||
crsSyncService *service.CRSSyncService
|
||||
sessionLimitCache service.SessionLimitCache
|
||||
tokenCacheInvalidator service.TokenCacheInvalidator
|
||||
}
|
||||
|
||||
// NewAccountHandler creates a new admin account handler
|
||||
@@ -60,6 +61,7 @@ func NewAccountHandler(
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
crsSyncService *service.CRSSyncService,
|
||||
sessionLimitCache service.SessionLimitCache,
|
||||
tokenCacheInvalidator service.TokenCacheInvalidator,
|
||||
) *AccountHandler {
|
||||
return &AccountHandler{
|
||||
adminService: adminService,
|
||||
@@ -73,6 +75,7 @@ func NewAccountHandler(
|
||||
concurrencyService: concurrencyService,
|
||||
crsSyncService: crsSyncService,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
tokenCacheInvalidator: tokenCacheInvalidator,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,13 +132,6 @@ type BulkUpdateAccountsRequest struct {
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
// AccountLookupRequest 用于凭证身份信息查找账号
|
||||
type AccountLookupRequest struct {
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Emails []string `json:"emails" binding:"required,min=1"`
|
||||
IdentityType string `json:"identity_type"`
|
||||
}
|
||||
|
||||
// AccountWithConcurrency extends Account with real-time concurrency info
|
||||
type AccountWithConcurrency struct {
|
||||
*dto.Account
|
||||
@@ -180,6 +176,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
@@ -188,6 +185,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -196,9 +194,9 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
|
||||
// 获取活跃会话数(批量查询)
|
||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
}
|
||||
@@ -218,12 +216,8 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
accCopy := acc // 闭包捕获
|
||||
g.Go(func() error {
|
||||
var startTime time.Time
|
||||
if accCopy.SessionWindowStart != nil {
|
||||
startTime = *accCopy.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := accCopy.GetCurrentWindowStartTime()
|
||||
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
|
||||
if err == nil && stats != nil {
|
||||
mu.Lock()
|
||||
@@ -265,87 +259,6 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
// Lookup 根据凭证身份信息查找账号
|
||||
// POST /api/v1/admin/accounts/lookup
|
||||
func (h *AccountHandler) Lookup(c *gin.Context) {
|
||||
var req AccountLookupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
identityType := strings.TrimSpace(req.IdentityType)
|
||||
if identityType == "" {
|
||||
identityType = "credential_email"
|
||||
}
|
||||
if identityType != "credential_email" {
|
||||
response.BadRequest(c, "Unsupported identity_type")
|
||||
return
|
||||
}
|
||||
|
||||
platform := strings.TrimSpace(req.Platform)
|
||||
if platform == "" {
|
||||
response.BadRequest(c, "Platform is required")
|
||||
return
|
||||
}
|
||||
|
||||
normalized := make([]string, 0, len(req.Emails))
|
||||
seen := make(map[string]struct{})
|
||||
for _, email := range req.Emails {
|
||||
cleaned := strings.ToLower(strings.TrimSpace(email))
|
||||
if cleaned == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[cleaned]; ok {
|
||||
continue
|
||||
}
|
||||
seen[cleaned] = struct{}{}
|
||||
normalized = append(normalized, cleaned)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
response.BadRequest(c, "Emails is required")
|
||||
return
|
||||
}
|
||||
|
||||
accounts, err := h.adminService.LookupAccountsByCredentialEmail(c.Request.Context(), platform, normalized)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
matchedMap := make(map[string]service.Account)
|
||||
for _, account := range accounts {
|
||||
email := strings.ToLower(strings.TrimSpace(account.GetCredential("email")))
|
||||
if email == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := matchedMap[email]; ok {
|
||||
continue
|
||||
}
|
||||
matchedMap[email] = account
|
||||
}
|
||||
|
||||
matched := make([]gin.H, 0, len(matchedMap))
|
||||
missing := make([]string, 0)
|
||||
for _, email := range normalized {
|
||||
if account, ok := matchedMap[email]; ok {
|
||||
matched = append(matched, gin.H{
|
||||
"email": email,
|
||||
"account_id": account.ID,
|
||||
"platform": account.Platform,
|
||||
"name": account.Name,
|
||||
})
|
||||
continue
|
||||
}
|
||||
missing = append(missing, email)
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"matched": matched,
|
||||
"missing": missing,
|
||||
})
|
||||
}
|
||||
|
||||
// GetByID handles getting an account by ID
|
||||
// GET /api/v1/admin/accounts/:id
|
||||
func (h *AccountHandler) GetByID(c *gin.Context) {
|
||||
@@ -634,9 +547,18 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
|
||||
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
|
||||
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
|
||||
newCredentials["project_id"] = oldProjectID
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证
|
||||
// 先更新凭证(token 本身刷新成功了)
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
@@ -644,14 +566,10 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
}
|
||||
// 标记账户为 error
|
||||
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil {
|
||||
response.InternalError(c, "Failed to set account error: "+setErr.Error())
|
||||
return
|
||||
}
|
||||
// 不标记为 error,只返回警告信息
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed but project_id is missing, account marked as error",
|
||||
"warning": "missing_project_id",
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -698,6 +616,14 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
@@ -747,6 +673,15 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清除错误后,同时清除 token 缓存,确保下次请求会获取最新的 token(触发刷新或从 DB 读取)
|
||||
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
|
||||
262
backend/internal/handler/admin/admin_basic_handlers_test.go
Normal file
262
backend/internal/handler/admin/admin_basic_handlers_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
adminSvc := newStubAdminService()
|
||||
|
||||
userHandler := NewUserHandler(adminSvc)
|
||||
groupHandler := NewGroupHandler(adminSvc)
|
||||
proxyHandler := NewProxyHandler(adminSvc)
|
||||
redeemHandler := NewRedeemHandler(adminSvc)
|
||||
|
||||
router.GET("/api/v1/admin/users", userHandler.List)
|
||||
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
|
||||
router.POST("/api/v1/admin/users", userHandler.Create)
|
||||
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
|
||||
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
|
||||
router.POST("/api/v1/admin/users/:id/balance", userHandler.UpdateBalance)
|
||||
router.GET("/api/v1/admin/users/:id/api-keys", userHandler.GetUserAPIKeys)
|
||||
router.GET("/api/v1/admin/users/:id/usage", userHandler.GetUserUsage)
|
||||
|
||||
router.GET("/api/v1/admin/groups", groupHandler.List)
|
||||
router.GET("/api/v1/admin/groups/all", groupHandler.GetAll)
|
||||
router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID)
|
||||
router.POST("/api/v1/admin/groups", groupHandler.Create)
|
||||
router.PUT("/api/v1/admin/groups/:id", groupHandler.Update)
|
||||
router.DELETE("/api/v1/admin/groups/:id", groupHandler.Delete)
|
||||
router.GET("/api/v1/admin/groups/:id/stats", groupHandler.GetStats)
|
||||
router.GET("/api/v1/admin/groups/:id/api-keys", groupHandler.GetGroupAPIKeys)
|
||||
|
||||
router.GET("/api/v1/admin/proxies", proxyHandler.List)
|
||||
router.GET("/api/v1/admin/proxies/all", proxyHandler.GetAll)
|
||||
router.GET("/api/v1/admin/proxies/:id", proxyHandler.GetByID)
|
||||
router.POST("/api/v1/admin/proxies", proxyHandler.Create)
|
||||
router.PUT("/api/v1/admin/proxies/:id", proxyHandler.Update)
|
||||
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
|
||||
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
|
||||
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
|
||||
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
|
||||
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
|
||||
|
||||
router.GET("/api/v1/admin/redeem-codes", redeemHandler.List)
|
||||
router.GET("/api/v1/admin/redeem-codes/:id", redeemHandler.GetByID)
|
||||
router.POST("/api/v1/admin/redeem-codes", redeemHandler.Generate)
|
||||
router.DELETE("/api/v1/admin/redeem-codes/:id", redeemHandler.Delete)
|
||||
router.POST("/api/v1/admin/redeem-codes/batch-delete", redeemHandler.BatchDelete)
|
||||
router.POST("/api/v1/admin/redeem-codes/:id/expire", redeemHandler.Expire)
|
||||
router.GET("/api/v1/admin/redeem-codes/:id/stats", redeemHandler.GetStats)
|
||||
|
||||
return router, adminSvc
|
||||
}
|
||||
|
||||
func TestUserHandlerEndpoints(t *testing.T) {
|
||||
router, _ := setupAdminRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/users?page=1&page_size=20", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
|
||||
body, _ := json.Marshal(createBody)
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
updateBody := map[string]any{"email": "updated@example.com"}
|
||||
body, _ = json.Marshal(updateBody)
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/users/1", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/users/1", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/balance", bytes.NewBufferString(`{"balance":1,"operation":"add"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/api-keys", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/usage?period=today", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestGroupHandlerEndpoints(t *testing.T) {
|
||||
router, _ := setupAdminRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/all", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ = json.Marshal(map[string]any{"name": "update"})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/groups/2", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/groups/2", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/api-keys", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestProxyHandlerEndpoints(t *testing.T) {
|
||||
router, _ := setupAdminRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/all", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{"name": "proxy", "protocol": "http", "host": "localhost", "port": 8080})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ = json.Marshal(map[string]any{"name": "proxy2"})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/proxies/4", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/proxies/4", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/test", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/accounts", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestRedeemHandlerEndpoints(t *testing.T) {
|
||||
router, _ := setupAdminRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{"count": 1, "type": "balance", "value": 10})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/redeem-codes/5", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/5/expire", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
134
backend/internal/handler/admin/admin_helpers_test.go
Normal file
134
backend/internal/handler/admin/admin_helpers_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseTimeRange(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req := httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-01&end_date=2024-01-02&timezone=UTC", nil)
|
||||
c.Request = req
|
||||
|
||||
start, end := parseTimeRange(c)
|
||||
require.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), start)
|
||||
require.Equal(t, time.Date(2024, 1, 3, 0, 0, 0, 0, time.UTC), end)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/?start_date=bad&timezone=UTC", nil)
|
||||
c.Request = req
|
||||
start, end = parseTimeRange(c)
|
||||
require.False(t, start.IsZero())
|
||||
require.False(t, end.IsZero())
|
||||
}
|
||||
|
||||
func TestParseOpsViewParam(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?view=excluded", nil)
|
||||
require.Equal(t, opsListViewExcluded, parseOpsViewParam(c))
|
||||
|
||||
c2, _ := gin.CreateTestContext(w)
|
||||
c2.Request = httptest.NewRequest(http.MethodGet, "/?view=all", nil)
|
||||
require.Equal(t, opsListViewAll, parseOpsViewParam(c2))
|
||||
|
||||
c3, _ := gin.CreateTestContext(w)
|
||||
c3.Request = httptest.NewRequest(http.MethodGet, "/?view=unknown", nil)
|
||||
require.Equal(t, opsListViewErrors, parseOpsViewParam(c3))
|
||||
|
||||
require.Equal(t, "", parseOpsViewParam(nil))
|
||||
}
|
||||
|
||||
func TestParseOpsDuration(t *testing.T) {
|
||||
dur, ok := parseOpsDuration("1h")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, time.Hour, dur)
|
||||
|
||||
_, ok = parseOpsDuration("invalid")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseOpsTimeRange(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
now := time.Now().UTC()
|
||||
startStr := now.Add(-time.Hour).Format(time.RFC3339)
|
||||
endStr := now.Format(time.RFC3339)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?start_time="+startStr+"&end_time="+endStr, nil)
|
||||
start, end, err := parseOpsTimeRange(c, "1h")
|
||||
require.NoError(t, err)
|
||||
require.True(t, start.Before(end))
|
||||
|
||||
c2, _ := gin.CreateTestContext(w)
|
||||
c2.Request = httptest.NewRequest(http.MethodGet, "/?start_time=bad", nil)
|
||||
_, _, err = parseOpsTimeRange(c2, "1h")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseOpsRealtimeWindow(t *testing.T) {
|
||||
dur, label, ok := parseOpsRealtimeWindow("5m")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 5*time.Minute, dur)
|
||||
require.Equal(t, "5min", label)
|
||||
|
||||
_, _, ok = parseOpsRealtimeWindow("invalid")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPickThroughputBucketSeconds(t *testing.T) {
|
||||
require.Equal(t, 60, pickThroughputBucketSeconds(30*time.Minute))
|
||||
require.Equal(t, 300, pickThroughputBucketSeconds(6*time.Hour))
|
||||
require.Equal(t, 3600, pickThroughputBucketSeconds(48*time.Hour))
|
||||
}
|
||||
|
||||
func TestParseOpsQueryMode(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?mode=raw", nil)
|
||||
require.Equal(t, service.ParseOpsQueryMode("raw"), parseOpsQueryMode(c))
|
||||
require.Equal(t, service.OpsQueryMode(""), parseOpsQueryMode(nil))
|
||||
}
|
||||
|
||||
func TestOpsAlertRuleValidation(t *testing.T) {
|
||||
raw := map[string]json.RawMessage{
|
||||
"name": json.RawMessage(`"High error rate"`),
|
||||
"metric_type": json.RawMessage(`"error_rate"`),
|
||||
"operator": json.RawMessage(`">"`),
|
||||
"threshold": json.RawMessage(`90`),
|
||||
}
|
||||
|
||||
validated, err := validateOpsAlertRulePayload(raw)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "High error rate", validated.Name)
|
||||
|
||||
_, err = validateOpsAlertRulePayload(map[string]json.RawMessage{})
|
||||
require.Error(t, err)
|
||||
|
||||
require.True(t, isPercentOrRateMetric("error_rate"))
|
||||
require.False(t, isPercentOrRateMetric("concurrency_queue_depth"))
|
||||
}
|
||||
|
||||
func TestOpsWSHelpers(t *testing.T) {
|
||||
prefixes, invalid := parseTrustedProxyList("10.0.0.0/8,invalid")
|
||||
require.Len(t, prefixes, 1)
|
||||
require.Len(t, invalid, 1)
|
||||
|
||||
host := hostWithoutPort("example.com:443")
|
||||
require.Equal(t, "example.com", host)
|
||||
|
||||
addr := netip.MustParseAddr("10.0.0.1")
|
||||
require.True(t, isAddrInTrustedProxies(addr, prefixes))
|
||||
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
|
||||
}
|
||||
294
backend/internal/handler/admin/admin_service_stub_test.go
Normal file
294
backend/internal/handler/admin/admin_service_stub_test.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type stubAdminService struct {
|
||||
users []service.User
|
||||
apiKeys []service.APIKey
|
||||
groups []service.Group
|
||||
accounts []service.Account
|
||||
proxies []service.Proxy
|
||||
proxyCounts []service.ProxyWithAccountCount
|
||||
redeems []service.RedeemCode
|
||||
}
|
||||
|
||||
func newStubAdminService() *stubAdminService {
|
||||
now := time.Now().UTC()
|
||||
user := service.User{
|
||||
ID: 1,
|
||||
Email: "user@example.com",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
apiKey := service.APIKey{
|
||||
ID: 10,
|
||||
UserID: user.ID,
|
||||
Key: "sk-test",
|
||||
Name: "test",
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
group := service.Group{
|
||||
ID: 2,
|
||||
Name: "group",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
account := service.Account{
|
||||
ID: 3,
|
||||
Name: "account",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
proxy := service.Proxy{
|
||||
ID: 4,
|
||||
Name: "proxy",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
redeem := service.RedeemCode{
|
||||
ID: 5,
|
||||
Code: "R-TEST",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 10,
|
||||
Status: service.StatusUnused,
|
||||
CreatedAt: now,
|
||||
}
|
||||
return &stubAdminService{
|
||||
users: []service.User{user},
|
||||
apiKeys: []service.APIKey{apiKey},
|
||||
groups: []service.Group{group},
|
||||
accounts: []service.Account{account},
|
||||
proxies: []service.Proxy{proxy},
|
||||
proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}},
|
||||
redeems: []service.RedeemCode{redeem},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) {
|
||||
return s.users, int64(len(s.users)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
|
||||
for i := range s.users {
|
||||
if s.users[i].ID == id {
|
||||
return &s.users[i], nil
|
||||
}
|
||||
}
|
||||
user := service.User{ID: id, Email: "user@example.com", Status: service.StatusActive}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateUser(ctx context.Context, input *service.CreateUserInput) (*service.User, error) {
|
||||
user := service.User{ID: 100, Email: input.Email, Status: service.StatusActive}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateUser(ctx context.Context, id int64, input *service.UpdateUserInput) (*service.User, error) {
|
||||
user := service.User{ID: id, Email: "updated@example.com", Status: service.StatusActive}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteUser(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*service.User, error) {
|
||||
user := service.User{ID: userID, Balance: balance, Status: service.StatusActive}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) {
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
|
||||
return map[string]any{"user_id": userID}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) {
|
||||
return s.groups, int64(len(s.groups)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAllGroups(ctx context.Context) ([]service.Group, error) {
|
||||
return s.groups, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
return s.groups, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Group, error) {
|
||||
group := service.Group{ID: id, Name: "group", Status: service.StatusActive}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) {
|
||||
group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateGroup(ctx context.Context, id int64, input *service.UpdateGroupInput) (*service.Group, error) {
|
||||
group := service.Group{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteGroup(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]service.APIKey, int64, error) {
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
|
||||
out := make([]*service.Account, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
|
||||
out = append(out, &account)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) {
|
||||
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) RefreshAccountCredentials(ctx context.Context, id int64) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
|
||||
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
|
||||
return s.proxies, int64(len(s.proxies)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
|
||||
return s.proxyCounts, int64(len(s.proxyCounts)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAllProxies(ctx context.Context) ([]service.Proxy, error) {
|
||||
return s.proxies, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
|
||||
return s.proxyCounts, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) {
|
||||
proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) {
|
||||
proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) {
|
||||
proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteProxy(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchDeleteProxies(ctx context.Context, ids []int64) (*service.ProxyBatchDeleteResult, error) {
|
||||
return &service.ProxyBatchDeleteResult{DeletedIDs: ids}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetProxyAccounts(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
|
||||
return []service.ProxyAccountSummary{{ID: 1, Name: "account"}}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) {
|
||||
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
|
||||
return s.redeems, int64(len(s.redeems)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
|
||||
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUnused}
|
||||
return &code, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GenerateRedeemCodes(ctx context.Context, input *service.GenerateRedeemCodesInput) ([]service.RedeemCode, error) {
|
||||
return s.redeems, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteRedeemCode(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
|
||||
return int64(len(ids)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
|
||||
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUsed}
|
||||
return &code, nil
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
246
backend/internal/handler/admin/announcement_handler.go
Normal file
246
backend/internal/handler/admin/announcement_handler.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AnnouncementHandler handles admin announcement management
|
||||
type AnnouncementHandler struct {
|
||||
announcementService *service.AnnouncementService
|
||||
}
|
||||
|
||||
// NewAnnouncementHandler creates a new admin announcement handler
|
||||
func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
|
||||
return &AnnouncementHandler{
|
||||
announcementService: announcementService,
|
||||
}
|
||||
}
|
||||
|
||||
type CreateAnnouncementRequest struct {
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
|
||||
}
|
||||
|
||||
type UpdateAnnouncementRequest struct {
|
||||
Title *string `json:"title"`
|
||||
Content *string `json:"content"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
|
||||
Targeting *service.AnnouncementTargeting `json:"targeting"`
|
||||
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
|
||||
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
|
||||
}
|
||||
|
||||
// List handles listing announcements with filters
|
||||
// GET /api/v1/admin/announcements
|
||||
func (h *AnnouncementHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
status := strings.TrimSpace(c.Query("status"))
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 200 {
|
||||
search = search[:200]
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
items, paginationResult, err := h.announcementService.List(
|
||||
c.Request.Context(),
|
||||
params,
|
||||
service.AnnouncementListFilters{Status: status, Search: search},
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.Announcement, 0, len(items))
|
||||
for i := range items {
|
||||
out = append(out, *dto.AnnouncementFromService(&items[i]))
|
||||
}
|
||||
response.Paginated(c, out, paginationResult.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting an announcement by ID
|
||||
// GET /api/v1/admin/announcements/:id
|
||||
func (h *AnnouncementHandler) GetByID(c *gin.Context) {
|
||||
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || announcementID <= 0 {
|
||||
response.BadRequest(c, "Invalid announcement ID")
|
||||
return
|
||||
}
|
||||
|
||||
item, err := h.announcementService.GetByID(c.Request.Context(), announcementID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AnnouncementFromService(item))
|
||||
}
|
||||
|
||||
// Create handles creating a new announcement
|
||||
// POST /api/v1/admin/announcements
|
||||
func (h *AnnouncementHandler) Create(c *gin.Context) {
|
||||
var req CreateAnnouncementRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.CreateAnnouncementInput{
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
}
|
||||
|
||||
if req.StartsAt != nil && *req.StartsAt > 0 {
|
||||
t := time.Unix(*req.StartsAt, 0)
|
||||
input.StartsAt = &t
|
||||
}
|
||||
if req.EndsAt != nil && *req.EndsAt > 0 {
|
||||
t := time.Unix(*req.EndsAt, 0)
|
||||
input.EndsAt = &t
|
||||
}
|
||||
|
||||
created, err := h.announcementService.Create(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AnnouncementFromService(created))
|
||||
}
|
||||
|
||||
// Update handles updating an announcement
|
||||
// PUT /api/v1/admin/announcements/:id
|
||||
func (h *AnnouncementHandler) Update(c *gin.Context) {
|
||||
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || announcementID <= 0 {
|
||||
response.BadRequest(c, "Invalid announcement ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAnnouncementRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.UpdateAnnouncementInput{
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: req.Status,
|
||||
Targeting: req.Targeting,
|
||||
ActorID: &subject.UserID,
|
||||
}
|
||||
|
||||
if req.StartsAt != nil {
|
||||
if *req.StartsAt == 0 {
|
||||
var cleared *time.Time = nil
|
||||
input.StartsAt = &cleared
|
||||
} else {
|
||||
t := time.Unix(*req.StartsAt, 0)
|
||||
ptr := &t
|
||||
input.StartsAt = &ptr
|
||||
}
|
||||
}
|
||||
|
||||
if req.EndsAt != nil {
|
||||
if *req.EndsAt == 0 {
|
||||
var cleared *time.Time = nil
|
||||
input.EndsAt = &cleared
|
||||
} else {
|
||||
t := time.Unix(*req.EndsAt, 0)
|
||||
ptr := &t
|
||||
input.EndsAt = &ptr
|
||||
}
|
||||
}
|
||||
|
||||
updated, err := h.announcementService.Update(c.Request.Context(), announcementID, input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AnnouncementFromService(updated))
|
||||
}
|
||||
|
||||
// Delete handles deleting an announcement
|
||||
// DELETE /api/v1/admin/announcements/:id
|
||||
func (h *AnnouncementHandler) Delete(c *gin.Context) {
|
||||
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || announcementID <= 0 {
|
||||
response.BadRequest(c, "Invalid announcement ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.announcementService.Delete(c.Request.Context(), announcementID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Announcement deleted successfully"})
|
||||
}
|
||||
|
||||
// ListReadStatus handles listing users read status for an announcement
|
||||
// GET /api/v1/admin/announcements/:id/read-status
|
||||
func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) {
|
||||
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || announcementID <= 0 {
|
||||
response.BadRequest(c, "Invalid announcement ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 200 {
|
||||
search = search[:200]
|
||||
}
|
||||
|
||||
items, paginationResult, err := h.announcementService.ListUserReadStatus(
|
||||
c.Request.Context(),
|
||||
announcementID,
|
||||
params,
|
||||
search,
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, items, paginationResult.Total, page, pageSize)
|
||||
}
|
||||
@@ -186,7 +186,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
||||
|
||||
// GetUsageTrend handles getting usage trend data
|
||||
// GET /api/v1/admin/dashboard/trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
|
||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
@@ -195,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
var model string
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
@@ -224,8 +225,17 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
stream = &streamVal
|
||||
}
|
||||
}
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
|
||||
bt := int8(v)
|
||||
billingType = &bt
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
@@ -241,13 +251,14 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
|
||||
// GetModelStats handles getting model usage statistics
|
||||
// GET /api/v1/admin/dashboard/models
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
|
||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
@@ -274,8 +285,17 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
stream = &streamVal
|
||||
}
|
||||
}
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
|
||||
bt := int8(v)
|
||||
billingType = &bt
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
|
||||
@@ -98,9 +98,9 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
outGroups := make([]dto.AdminGroup, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
|
||||
}
|
||||
response.Paginated(c, outGroups, total, page, pageSize)
|
||||
}
|
||||
@@ -124,9 +124,9 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
outGroups := make([]dto.AdminGroup, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
|
||||
}
|
||||
response.Success(c, outGroups)
|
||||
}
|
||||
@@ -146,7 +146,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// Create handles creating a new group
|
||||
@@ -183,7 +183,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// Update handles updating a group
|
||||
@@ -227,7 +227,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// Delete handles deleting a group
|
||||
|
||||
@@ -54,9 +54,9 @@ func (h *RedeemHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
@@ -76,7 +76,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
|
||||
}
|
||||
|
||||
// Generate handles generating new redeem codes
|
||||
@@ -100,9 +100,9 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
|
||||
}
|
||||
|
||||
// GetStats handles getting redeem code statistics
|
||||
|
||||
@@ -47,6 +47,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
@@ -68,6 +72,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
@@ -87,8 +94,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
// 邮件服务设置
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
@@ -111,13 +121,16 @@ type UpdateSettingsRequest struct {
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -194,6 +207,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// TOTP 双因素认证参数验证
|
||||
// 只有手动配置了加密密钥才允许启用 TOTP 功能
|
||||
if req.TotpEnabled && !previousSettings.TotpEnabled {
|
||||
// 尝试启用 TOTP,检查加密密钥是否已手动配置
|
||||
if !h.settingService.IsTotpEncryptionKeyConfigured() {
|
||||
response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// LinuxDo Connect 参数验证
|
||||
if req.LinuxDoConnectEnabled {
|
||||
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
|
||||
@@ -223,6 +246,34 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// “购买订阅”页面配置验证
|
||||
purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled
|
||||
if req.PurchaseSubscriptionEnabled != nil {
|
||||
purchaseEnabled = *req.PurchaseSubscriptionEnabled
|
||||
}
|
||||
purchaseURL := previousSettings.PurchaseSubscriptionURL
|
||||
if req.PurchaseSubscriptionURL != nil {
|
||||
purchaseURL = strings.TrimSpace(*req.PurchaseSubscriptionURL)
|
||||
}
|
||||
|
||||
// - 启用时要求 URL 合法且非空
|
||||
// - 禁用时允许为空;若提供了 URL 也做基本校验,避免误配置
|
||||
if purchaseEnabled {
|
||||
if purchaseURL == "" {
|
||||
response.BadRequest(c, "Purchase Subscription URL is required when enabled")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
|
||||
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
} else if purchaseURL != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
|
||||
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Ops metrics collector interval validation (seconds).
|
||||
if req.OpsMetricsIntervalSeconds != nil {
|
||||
v := *req.OpsMetricsIntervalSeconds
|
||||
@@ -236,38 +287,44 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -311,6 +368,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
TotpEnabled: updatedSettings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
@@ -332,6 +393,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
HomeContent: updatedSettings.HomeContent,
|
||||
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
@@ -376,6 +440,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
||||
changed = append(changed, "email_verify_enabled")
|
||||
}
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
if before.TotpEnabled != after.TotpEnabled {
|
||||
changed = append(changed, "totp_enabled")
|
||||
}
|
||||
if before.SMTPHost != after.SMTPHost {
|
||||
changed = append(changed, "smtp_host")
|
||||
}
|
||||
@@ -439,6 +509,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.HomeContent != after.HomeContent {
|
||||
changed = append(changed, "home_content")
|
||||
}
|
||||
if before.HideCcsImportButton != after.HideCcsImportButton {
|
||||
changed = append(changed, "hide_ccs_import_button")
|
||||
}
|
||||
if before.DefaultConcurrency != after.DefaultConcurrency {
|
||||
changed = append(changed, "default_concurrency")
|
||||
}
|
||||
|
||||
@@ -53,9 +53,9 @@ type BulkAssignSubscriptionRequest struct {
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// ExtendSubscriptionRequest represents extend subscription request
|
||||
type ExtendSubscriptionRequest struct {
|
||||
Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years
|
||||
// AdjustSubscriptionRequest represents adjust subscription request (extend or shorten)
|
||||
type AdjustSubscriptionRequest struct {
|
||||
Days int `json:"days" binding:"required,min=-36500,max=36500"` // negative to shorten, positive to extend
|
||||
}
|
||||
|
||||
// List handles listing all subscriptions with pagination and filters
|
||||
@@ -77,15 +77,19 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
}
|
||||
status := c.Query("status")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
|
||||
// Parse sorting parameters
|
||||
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
@@ -105,7 +109,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
}
|
||||
|
||||
// GetProgress handles getting subscription usage progress
|
||||
@@ -150,7 +154,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
}
|
||||
|
||||
// BulkAssign handles bulk assigning subscriptions to multiple users
|
||||
@@ -180,7 +184,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
|
||||
response.Success(c, dto.BulkAssignResultFromService(result))
|
||||
}
|
||||
|
||||
// Extend handles extending a subscription
|
||||
// Extend handles adjusting a subscription (extend or shorten)
|
||||
// POST /api/v1/admin/subscriptions/:id/extend
|
||||
func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
@@ -189,7 +193,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var req ExtendSubscriptionRequest
|
||||
var req AdjustSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
@@ -201,7 +205,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
@@ -239,9 +243,9 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
@@ -261,9 +265,9 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
377
backend/internal/handler/admin/usage_cleanup_handler_test.go
Normal file
377
backend/internal/handler/admin/usage_cleanup_handler_test.go
Normal file
@@ -0,0 +1,377 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type cleanupRepoStub struct {
|
||||
mu sync.Mutex
|
||||
created []*service.UsageCleanupTask
|
||||
listTasks []service.UsageCleanupTask
|
||||
listResult *pagination.PaginationResult
|
||||
listErr error
|
||||
statusByID map[int64]string
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if task.ID == 0 {
|
||||
task.ID = int64(len(s.created) + 1)
|
||||
}
|
||||
if task.CreatedAt.IsZero() {
|
||||
task.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
task.UpdatedAt = task.CreatedAt
|
||||
clone := *task
|
||||
s.created = append(s.created, &clone)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.listTasks, s.listResult, s.listErr
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.statusByID == nil {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
status, ok := s.statusByID[taskID]
|
||||
if !ok {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
status := s.statusByID[taskID]
|
||||
if status != service.UsageCleanupStatusPending && status != service.UsageCleanupStatusRunning {
|
||||
return false, nil
|
||||
}
|
||||
s.statusByID[taskID] = service.UsageCleanupStatusCanceled
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var _ service.UsageCleanupRepository = (*cleanupRepoStub)(nil)
|
||||
|
||||
func setupCleanupRouter(cleanupService *service.UsageCleanupService, userID int64) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
if userID > 0 {
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
|
||||
handler := NewUsageHandler(nil, nil, nil, cleanupService)
|
||||
router.POST("/api/v1/admin/usage/cleanup-tasks", handler.CreateCleanupTask)
|
||||
router.GET("/api/v1/admin/usage/cleanup-tasks", handler.ListCleanupTasks)
|
||||
router.POST("/api/v1/admin/usage/cleanup-tasks/:id/cancel", handler.CancelCleanupTask)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskUnauthorized(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 0)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskUnavailable(t *testing.T) {
|
||||
router := setupCleanupRouter(nil, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskBindError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString("{bad-json"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskMissingRange(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"timezone": "UTC",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskInvalidDate(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-13-01",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-02-40",
|
||||
"timezone": "UTC",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 99)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": " 2024-01-01 ",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
"model": "gpt-4",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.created, 1)
|
||||
created := repo.created[0]
|
||||
require.Equal(t, int64(99), created.CreatedBy)
|
||||
require.NotNil(t, created.Filters.Model)
|
||||
require.Equal(t, "gpt-4", *created.Filters.Model)
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC).Add(24*time.Hour - time.Nanosecond)
|
||||
require.True(t, created.Filters.StartTime.Equal(start))
|
||||
require.True(t, created.Filters.EndTime.Equal(end))
|
||||
}
|
||||
|
||||
func TestUsageHandlerListCleanupTasksUnavailable(t *testing.T) {
|
||||
router := setupCleanupRouter(nil, 0)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerListCleanupTasksSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
repo.listTasks = []service.UsageCleanupTask{
|
||||
{
|
||||
ID: 7,
|
||||
Status: service.UsageCleanupStatusSucceeded,
|
||||
CreatedBy: 4,
|
||||
},
|
||||
}
|
||||
repo.listResult = &pagination.PaginationResult{Total: 1, Page: 1, PageSize: 20, Pages: 1}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
Items []dto.UsageCleanupTask `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Len(t, resp.Data.Items, 1)
|
||||
require.Equal(t, int64(7), resp.Data.Items[0].ID)
|
||||
require.Equal(t, int64(1), resp.Data.Total)
|
||||
require.Equal(t, 1, resp.Data.Page)
|
||||
}
|
||||
|
||||
func TestUsageHandlerListCleanupTasksError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{listErr: errors.New("boom")}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCancelCleanupTaskUnauthorized(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 0)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/1/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCancelCleanupTaskNotFound(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/999/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCancelCleanupTaskConflict(t *testing.T) {
|
||||
repo := &cleanupRepoStub{statusByID: map[int64]string{2: service.UsageCleanupStatusSucceeded}}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/2/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusConflict, rec.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCancelCleanupTaskSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{statusByID: map[int64]string{3: service.UsageCleanupStatusPending}}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/3/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
@@ -1,7 +1,10 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -9,6 +12,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -16,9 +20,10 @@ import (
|
||||
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.APIKeyService
|
||||
adminService service.AdminService
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.APIKeyService
|
||||
adminService service.AdminService
|
||||
cleanupService *service.UsageCleanupService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
@@ -26,14 +31,30 @@ func NewUsageHandler(
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
adminService service.AdminService,
|
||||
cleanupService *service.UsageCleanupService,
|
||||
) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
adminService: adminService,
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
adminService: adminService,
|
||||
cleanupService: cleanupService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateUsageCleanupTaskRequest represents cleanup task creation request
|
||||
type CreateUsageCleanupTaskRequest struct {
|
||||
StartDate string `json:"start_date"`
|
||||
EndDate string `json:"end_date"`
|
||||
UserID *int64 `json:"user_id"`
|
||||
APIKeyID *int64 `json:"api_key_id"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Model *string `json:"model"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
Timezone string `json:"timezone"`
|
||||
}
|
||||
|
||||
// List handles listing all usage records with filters
|
||||
// GET /api/v1/admin/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
@@ -142,7 +163,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UsageLog, 0, len(records))
|
||||
out := make([]dto.AdminUsageLog, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
|
||||
}
|
||||
@@ -344,3 +365,162 @@ func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ListCleanupTasks handles listing usage cleanup tasks
|
||||
// GET /api/v1/admin/usage/cleanup-tasks
|
||||
func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
|
||||
if h.cleanupService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
|
||||
return
|
||||
}
|
||||
operator := int64(0)
|
||||
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
|
||||
operator = subject.UserID
|
||||
}
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
out := make([]dto.UsageCleanupTask, 0, len(tasks))
|
||||
for i := range tasks {
|
||||
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
|
||||
}
|
||||
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// CreateCleanupTask handles creating a usage cleanup task
|
||||
// POST /api/v1/admin/usage/cleanup-tasks
|
||||
func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
if h.cleanupService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
|
||||
return
|
||||
}
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Unauthorized(c, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateUsageCleanupTaskRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
req.StartDate = strings.TrimSpace(req.StartDate)
|
||||
req.EndDate = strings.TrimSpace(req.EndDate)
|
||||
if req.StartDate == "" || req.EndDate == "" {
|
||||
response.BadRequest(c, "start_date and end_date are required")
|
||||
return
|
||||
}
|
||||
|
||||
startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
|
||||
filters := service.UsageCleanupFilters{
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
UserID: req.UserID,
|
||||
APIKeyID: req.APIKeyID,
|
||||
AccountID: req.AccountID,
|
||||
GroupID: req.GroupID,
|
||||
Model: req.Model,
|
||||
Stream: req.Stream,
|
||||
BillingType: req.BillingType,
|
||||
}
|
||||
|
||||
var userID any
|
||||
if filters.UserID != nil {
|
||||
userID = *filters.UserID
|
||||
}
|
||||
var apiKeyID any
|
||||
if filters.APIKeyID != nil {
|
||||
apiKeyID = *filters.APIKeyID
|
||||
}
|
||||
var accountID any
|
||||
if filters.AccountID != nil {
|
||||
accountID = *filters.AccountID
|
||||
}
|
||||
var groupID any
|
||||
if filters.GroupID != nil {
|
||||
groupID = *filters.GroupID
|
||||
}
|
||||
var model any
|
||||
if filters.Model != nil {
|
||||
model = *filters.Model
|
||||
}
|
||||
var stream any
|
||||
if filters.Stream != nil {
|
||||
stream = *filters.Stream
|
||||
}
|
||||
var billingType any
|
||||
if filters.BillingType != nil {
|
||||
billingType = *filters.BillingType
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
||||
subject.UserID,
|
||||
filters.StartTime.Format(time.RFC3339),
|
||||
filters.EndTime.Format(time.RFC3339),
|
||||
userID,
|
||||
apiKeyID,
|
||||
accountID,
|
||||
groupID,
|
||||
model,
|
||||
stream,
|
||||
billingType,
|
||||
req.Timezone,
|
||||
)
|
||||
|
||||
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
|
||||
response.Success(c, dto.UsageCleanupTaskFromService(task))
|
||||
}
|
||||
|
||||
// CancelCleanupTask handles canceling a usage cleanup task
|
||||
// POST /api/v1/admin/usage/cleanup-tasks/:id/cancel
|
||||
func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
|
||||
if h.cleanupService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
|
||||
return
|
||||
}
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Unauthorized(c, "Unauthorized")
|
||||
return
|
||||
}
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
taskID, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || taskID <= 0 {
|
||||
response.BadRequest(c, "Invalid task id")
|
||||
return
|
||||
}
|
||||
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
|
||||
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
|
||||
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
|
||||
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
|
||||
}
|
||||
|
||||
@@ -84,9 +84,9 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.User, 0, len(users))
|
||||
out := make([]dto.AdminUser, 0, len(users))
|
||||
for i := range users {
|
||||
out = append(out, *dto.UserFromService(&users[i]))
|
||||
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// Create handles creating a new user
|
||||
@@ -155,7 +155,7 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// Update handles updating a user
|
||||
@@ -189,7 +189,7 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// Delete handles deleting a user
|
||||
@@ -231,7 +231,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handles getting user's API keys
|
||||
|
||||
81
backend/internal/handler/announcement_handler.go
Normal file
81
backend/internal/handler/announcement_handler.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AnnouncementHandler handles user announcement operations
|
||||
type AnnouncementHandler struct {
|
||||
announcementService *service.AnnouncementService
|
||||
}
|
||||
|
||||
// NewAnnouncementHandler creates a new user announcement handler
|
||||
func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
|
||||
return &AnnouncementHandler{
|
||||
announcementService: announcementService,
|
||||
}
|
||||
}
|
||||
|
||||
// List handles listing announcements visible to current user
|
||||
// GET /api/v1/announcements
|
||||
func (h *AnnouncementHandler) List(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
unreadOnly := parseBoolQuery(c.Query("unread_only"))
|
||||
|
||||
items, err := h.announcementService.ListForUser(c.Request.Context(), subject.UserID, unreadOnly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserAnnouncement, 0, len(items))
|
||||
for i := range items {
|
||||
out = append(out, *dto.UserAnnouncementFromService(&items[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// MarkRead marks an announcement as read for current user
|
||||
// POST /api/v1/announcements/:id/read
|
||||
func (h *AnnouncementHandler) MarkRead(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || announcementID <= 0 {
|
||||
response.BadRequest(c, "Invalid announcement ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.announcementService.MarkRead(c.Request.Context(), subject.UserID, announcementID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
func parseBoolQuery(v string) bool {
|
||||
switch strings.TrimSpace(strings.ToLower(v)) {
|
||||
case "1", "true", "yes", "y", "on":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
@@ -18,16 +20,18 @@ type AuthHandler struct {
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
totpService *service.TotpService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
totpService: totpService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,6 +148,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if TOTP 2FA is enabled for this user
|
||||
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
|
||||
// Create a temporary login session for 2FA
|
||||
tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to create 2FA session")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, TotpLoginResponse{
|
||||
Requires2FA: true,
|
||||
TempToken: tempToken,
|
||||
UserEmailMasked: service.MaskEmail(user.Email),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
// TotpLoginResponse represents the response when 2FA is required
|
||||
type TotpLoginResponse struct {
|
||||
Requires2FA bool `json:"requires_2fa"`
|
||||
TempToken string `json:"temp_token,omitempty"`
|
||||
UserEmailMasked string `json:"user_email_masked,omitempty"`
|
||||
}
|
||||
|
||||
// Login2FARequest represents the 2FA login request
|
||||
type Login2FARequest struct {
|
||||
TempToken string `json:"temp_token" binding:"required"`
|
||||
TotpCode string `json:"totp_code" binding:"required,len=6"`
|
||||
}
|
||||
|
||||
// Login2FA completes the login with 2FA verification
|
||||
// POST /api/v1/auth/login/2fa
|
||||
func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
var req Login2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("login_2fa_request",
|
||||
"temp_token_len", len(req.TempToken),
|
||||
"totp_code_len", len(req.TotpCode))
|
||||
|
||||
// Get the login session
|
||||
session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
|
||||
if err != nil || session == nil {
|
||||
tokenPrefix := ""
|
||||
if len(req.TempToken) >= 8 {
|
||||
tokenPrefix = req.TempToken[:8]
|
||||
}
|
||||
slog.Debug("login_2fa_session_invalid",
|
||||
"temp_token_prefix", tokenPrefix,
|
||||
"error", err)
|
||||
response.BadRequest(c, "Invalid or expired 2FA session")
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("login_2fa_session_found",
|
||||
"user_id", session.UserID,
|
||||
"email", session.Email)
|
||||
|
||||
// Verify the TOTP code
|
||||
if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
|
||||
slog.Debug("login_2fa_verify_failed",
|
||||
"user_id", session.UserID,
|
||||
"error", err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate the JWT token
|
||||
token, err := h.authService.GenerateToken(user)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
@@ -195,6 +293,15 @@ type ValidatePromoCodeResponse struct {
|
||||
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
|
||||
// POST /api/v1/auth/validate-promo-code
|
||||
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
// 检查优惠码功能是否启用
|
||||
if h.settingSvc != nil && !h.settingSvc.IsPromoCodeEnabled(c.Request.Context()) {
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: "PROMO_CODE_DISABLED",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req ValidatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
@@ -238,3 +345,85 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
})
|
||||
}
|
||||
|
||||
// ForgotPasswordRequest 忘记密码请求
|
||||
type ForgotPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// ForgotPasswordResponse 忘记密码响应
|
||||
type ForgotPasswordResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ForgotPassword 请求密码重置
|
||||
// POST /api/v1/auth/forgot-password
|
||||
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
var req ForgotPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build frontend base URL from request
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
// Check X-Forwarded-Proto header (common in reverse proxy setups)
|
||||
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
frontendBaseURL := scheme + "://" + c.Request.Host
|
||||
|
||||
// Request password reset (async)
|
||||
// Note: This returns success even if email doesn't exist (to prevent enumeration)
|
||||
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ForgotPasswordResponse{
|
||||
Message: "If your email is registered, you will receive a password reset link shortly.",
|
||||
})
|
||||
}
|
||||
|
||||
// ResetPasswordRequest 重置密码请求
|
||||
type ResetPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Token string `json:"token" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// ResetPasswordResponse 重置密码响应
|
||||
type ResetPasswordResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
// POST /api/v1/auth/reset-password
|
||||
func (h *AuthHandler) ResetPassword(c *gin.Context) {
|
||||
var req ResetPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Reset password
|
||||
if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ResetPasswordResponse{
|
||||
Message: "Your password has been reset successfully. You can now log in with your new password.",
|
||||
})
|
||||
}
|
||||
|
||||
74
backend/internal/handler/dto/announcement.go
Normal file
74
backend/internal/handler/dto/announcement.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type Announcement struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
|
||||
Targeting service.AnnouncementTargeting `json:"targeting"`
|
||||
|
||||
StartsAt *time.Time `json:"starts_at,omitempty"`
|
||||
EndsAt *time.Time `json:"ends_at,omitempty"`
|
||||
|
||||
CreatedBy *int64 `json:"created_by,omitempty"`
|
||||
UpdatedBy *int64 `json:"updated_by,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserAnnouncement struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
|
||||
StartsAt *time.Time `json:"starts_at,omitempty"`
|
||||
EndsAt *time.Time `json:"ends_at,omitempty"`
|
||||
|
||||
ReadAt *time.Time `json:"read_at,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func AnnouncementFromService(a *service.Announcement) *Announcement {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
return &Announcement{
|
||||
ID: a.ID,
|
||||
Title: a.Title,
|
||||
Content: a.Content,
|
||||
Status: a.Status,
|
||||
Targeting: a.Targeting,
|
||||
StartsAt: a.StartsAt,
|
||||
EndsAt: a.EndsAt,
|
||||
CreatedBy: a.CreatedBy,
|
||||
UpdatedBy: a.UpdatedBy,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserAnnouncement{
|
||||
ID: a.Announcement.ID,
|
||||
Title: a.Announcement.Title,
|
||||
Content: a.Announcement.Content,
|
||||
StartsAt: a.Announcement.StartsAt,
|
||||
EndsAt: a.Announcement.EndsAt,
|
||||
ReadAt: a.ReadAt,
|
||||
CreatedAt: a.Announcement.CreatedAt,
|
||||
UpdatedAt: a.Announcement.UpdatedAt,
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,6 @@ func UserFromServiceShallow(u *service.User) *User {
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
@@ -48,6 +47,22 @@ func UserFromService(u *service.User) *User {
|
||||
return out
|
||||
}
|
||||
|
||||
// UserFromServiceAdmin converts a service User to DTO for admin users.
|
||||
// It includes notes - user-facing endpoints must not use this.
|
||||
func UserFromServiceAdmin(u *service.User) *AdminUser {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
base := UserFromService(u)
|
||||
if base == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminUser{
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
}
|
||||
}
|
||||
|
||||
func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
@@ -72,38 +87,31 @@ func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return &Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
out := groupFromServiceBase(g)
|
||||
return &out
|
||||
}
|
||||
|
||||
func GroupFromService(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
out := GroupFromServiceShallow(g)
|
||||
return GroupFromServiceShallow(g)
|
||||
}
|
||||
|
||||
// GroupFromServiceAdmin converts a service Group to DTO for admin users.
|
||||
// It includes internal fields like model_routing and account_count.
|
||||
func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
out := &AdminGroup{
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
for i := range g.AccountGroups {
|
||||
@@ -114,6 +122,31 @@ func GroupFromService(g *service.Group) *Group {
|
||||
return out
|
||||
}
|
||||
|
||||
func groupFromServiceBase(g *service.Group) Group {
|
||||
return Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
// 无效请求兜底分组
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a == nil {
|
||||
return nil
|
||||
@@ -163,6 +196,16 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
|
||||
out.SessionIdleTimeoutMin = &idleTimeout
|
||||
}
|
||||
// TLS指纹伪装开关
|
||||
if a.IsTLSFingerprintEnabled() {
|
||||
enabled := true
|
||||
out.EnableTLSFingerprint = &enabled
|
||||
}
|
||||
// 会话ID伪装开关
|
||||
if a.IsSessionIDMaskingEnabled() {
|
||||
enabled := true
|
||||
out.EnableSessionIDMasking = &enabled
|
||||
}
|
||||
}
|
||||
|
||||
if scopeLimits := a.GetAntigravityScopeRateLimits(); len(scopeLimits) > 0 {
|
||||
@@ -276,7 +319,24 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
||||
if rc == nil {
|
||||
return nil
|
||||
}
|
||||
return &RedeemCode{
|
||||
out := redeemCodeFromServiceBase(rc)
|
||||
return &out
|
||||
}
|
||||
|
||||
// RedeemCodeFromServiceAdmin converts a service RedeemCode to DTO for admin users.
|
||||
// It includes notes - user-facing endpoints must not use this.
|
||||
func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
|
||||
if rc == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminRedeemCode{
|
||||
RedeemCode: redeemCodeFromServiceBase(rc),
|
||||
Notes: rc.Notes,
|
||||
}
|
||||
}
|
||||
|
||||
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
|
||||
out := RedeemCode{
|
||||
ID: rc.ID,
|
||||
Code: rc.Code,
|
||||
Type: rc.Type,
|
||||
@@ -284,13 +344,20 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
||||
Status: rc.Status,
|
||||
UsedBy: rc.UsedBy,
|
||||
UsedAt: rc.UsedAt,
|
||||
Notes: rc.Notes,
|
||||
CreatedAt: rc.CreatedAt,
|
||||
GroupID: rc.GroupID,
|
||||
ValidityDays: rc.ValidityDays,
|
||||
User: UserFromServiceShallow(rc.User),
|
||||
Group: GroupFromServiceShallow(rc.Group),
|
||||
}
|
||||
|
||||
// For admin_balance/admin_concurrency types, include notes so users can see
|
||||
// why they were charged or credited by admin
|
||||
if (rc.Type == "admin_balance" || rc.Type == "admin_concurrency") && rc.Notes != "" {
|
||||
out.Notes = &rc.Notes
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// AccountSummaryFromService returns a minimal AccountSummary for usage log display.
|
||||
@@ -305,14 +372,9 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
||||
}
|
||||
}
|
||||
|
||||
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
|
||||
// The account parameter allows caller to control what Account info is included.
|
||||
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
|
||||
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
result := &UsageLog{
|
||||
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
|
||||
return UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
APIKeyID: l.APIKeyID,
|
||||
@@ -334,7 +396,6 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
|
||||
TotalCost: l.TotalCost,
|
||||
ActualCost: l.ActualCost,
|
||||
RateMultiplier: l.RateMultiplier,
|
||||
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||
BillingType: l.BillingType,
|
||||
Stream: l.Stream,
|
||||
DurationMs: l.DurationMs,
|
||||
@@ -345,30 +406,63 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
APIKey: APIKeyFromService(l.APIKey),
|
||||
Account: account,
|
||||
Group: GroupFromServiceShallow(l.Group),
|
||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||
}
|
||||
// IP 地址仅对管理员可见
|
||||
if includeIPAddress {
|
||||
result.IPAddress = l.IPAddress
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// UsageLogFromService converts a service UsageLog to DTO for regular users.
|
||||
// It excludes Account details and IP address - users should not see these.
|
||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
return usageLogFromServiceBase(l, nil, false)
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
u := usageLogFromServiceUser(l)
|
||||
return &u
|
||||
}
|
||||
|
||||
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
|
||||
// It includes minimal Account info (ID, Name only) and IP address.
|
||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
|
||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
|
||||
return &AdminUsageLog{
|
||||
UsageLog: usageLogFromServiceUser(l),
|
||||
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||
IPAddress: l.IPAddress,
|
||||
Account: AccountSummaryFromService(l.Account),
|
||||
}
|
||||
}
|
||||
|
||||
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageCleanupTask{
|
||||
ID: task.ID,
|
||||
Status: task.Status,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: task.Filters.StartTime,
|
||||
EndTime: task.Filters.EndTime,
|
||||
UserID: task.Filters.UserID,
|
||||
APIKeyID: task.Filters.APIKeyID,
|
||||
AccountID: task.Filters.AccountID,
|
||||
GroupID: task.Filters.GroupID,
|
||||
Model: task.Filters.Model,
|
||||
Stream: task.Filters.Stream,
|
||||
BillingType: task.Filters.BillingType,
|
||||
},
|
||||
CreatedBy: task.CreatedBy,
|
||||
DeletedRows: task.DeletedRows,
|
||||
ErrorMessage: task.ErrorMsg,
|
||||
CanceledBy: task.CanceledBy,
|
||||
CanceledAt: task.CanceledAt,
|
||||
StartedAt: task.StartedAt,
|
||||
FinishedAt: task.FinishedAt,
|
||||
CreatedAt: task.CreatedAt,
|
||||
UpdatedAt: task.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func SettingFromService(s *service.Setting) *Setting {
|
||||
@@ -387,7 +481,27 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserSubscription{
|
||||
out := userSubscriptionFromServiceBase(sub)
|
||||
return &out
|
||||
}
|
||||
|
||||
// UserSubscriptionFromServiceAdmin converts a service UserSubscription to DTO for admin users.
|
||||
// It includes assignment metadata and notes.
|
||||
func UserSubscriptionFromServiceAdmin(sub *service.UserSubscription) *AdminUserSubscription {
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminUserSubscription{
|
||||
UserSubscription: userSubscriptionFromServiceBase(sub),
|
||||
AssignedBy: sub.AssignedBy,
|
||||
AssignedAt: sub.AssignedAt,
|
||||
Notes: sub.Notes,
|
||||
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
||||
}
|
||||
}
|
||||
|
||||
func userSubscriptionFromServiceBase(sub *service.UserSubscription) UserSubscription {
|
||||
return UserSubscription{
|
||||
ID: sub.ID,
|
||||
UserID: sub.UserID,
|
||||
GroupID: sub.GroupID,
|
||||
@@ -400,14 +514,10 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
|
||||
DailyUsageUSD: sub.DailyUsageUSD,
|
||||
WeeklyUsageUSD: sub.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: sub.MonthlyUsageUSD,
|
||||
AssignedBy: sub.AssignedBy,
|
||||
AssignedAt: sub.AssignedAt,
|
||||
Notes: sub.Notes,
|
||||
CreatedAt: sub.CreatedAt,
|
||||
UpdatedAt: sub.UpdatedAt,
|
||||
User: UserFromServiceShallow(sub.User),
|
||||
Group: GroupFromServiceShallow(sub.Group),
|
||||
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -415,9 +525,9 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
subs := make([]UserSubscription, 0, len(r.Subscriptions))
|
||||
subs := make([]AdminUserSubscription, 0, len(r.Subscriptions))
|
||||
for i := range r.Subscriptions {
|
||||
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
|
||||
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
|
||||
}
|
||||
return &BulkAssignResult{
|
||||
SuccessCount: r.SuccessCount,
|
||||
|
||||
@@ -2,8 +2,12 @@ package dto
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
@@ -22,13 +26,16 @@ type SystemSettings struct {
|
||||
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
@@ -52,19 +59,25 @@ type SystemSettings struct {
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||
|
||||
@@ -11,7 +11,6 @@ type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
@@ -24,6 +23,14 @@ type User struct {
|
||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUser 是管理员接口使用的 user DTO(包含敏感/内部字段)。
|
||||
// 注意:普通用户接口不得返回 notes 等管理员备注信息。
|
||||
type AdminUser struct {
|
||||
User
|
||||
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -65,6 +72,15 @@ type Group struct {
|
||||
// 无效请求兜底分组
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AdminGroup 是管理员接口使用的 group DTO(包含敏感/内部字段)。
|
||||
// 注意:普通用户接口不得返回 model_routing/account_count/account_groups 等内部信息。
|
||||
type AdminGroup struct {
|
||||
Group
|
||||
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
@@ -72,9 +88,6 @@ type Group struct {
|
||||
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
}
|
||||
@@ -125,6 +138,15 @@ type Account struct {
|
||||
MaxSessions *int `json:"max_sessions,omitempty"`
|
||||
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
|
||||
|
||||
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
|
||||
|
||||
// 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 启用后将在15分钟内固定 metadata.user_id 中的 session ID
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
@@ -184,16 +206,28 @@ type RedeemCode struct {
|
||||
Status string `json:"status"`
|
||||
UsedBy *int64 `json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
Notes string `json:"notes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
|
||||
// Notes is only populated for admin_balance/admin_concurrency types
|
||||
// so users can see why they were charged or credited
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
// AdminRedeemCode 是管理员接口使用的 redeem code DTO(包含 notes 等字段)。
|
||||
// 注意:普通用户接口不得返回 notes 等内部信息。
|
||||
type AdminRedeemCode struct {
|
||||
RedeemCode
|
||||
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// UsageLog 是普通用户接口使用的 usage log DTO(不包含管理员字段)。
|
||||
type UsageLog struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -213,14 +247,13 @@ type UsageLog struct {
|
||||
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
|
||||
|
||||
InputCost float64 `json:"input_cost"`
|
||||
OutputCost float64 `json:"output_cost"`
|
||||
CacheCreationCost float64 `json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `json:"cache_read_cost"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||
InputCost float64 `json:"input_cost"`
|
||||
OutputCost float64 `json:"output_cost"`
|
||||
CacheCreationCost float64 `json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `json:"cache_read_cost"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
|
||||
BillingType int8 `json:"billing_type"`
|
||||
Stream bool `json:"stream"`
|
||||
@@ -234,18 +267,55 @@ type UsageLog struct {
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
|
||||
// IP 地址(仅管理员可见)
|
||||
IPAddress *string `json:"ip_address,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
APIKey *APIKey `json:"api_key,omitempty"`
|
||||
Account *AccountSummary `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
|
||||
Group *Group `json:"group,omitempty"`
|
||||
Subscription *UserSubscription `json:"subscription,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUsageLog 是管理员接口使用的 usage log DTO(包含管理员字段)。
|
||||
type AdminUsageLog struct {
|
||||
UsageLog
|
||||
|
||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
|
||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||
|
||||
// IPAddress 用户请求 IP(仅管理员可见)
|
||||
IPAddress *string `json:"ip_address,omitempty"`
|
||||
|
||||
// Account 最小账号信息(避免泄露敏感字段)
|
||||
Account *AccountSummary `json:"account,omitempty"`
|
||||
}
|
||||
|
||||
type UsageCleanupFilters struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
UserID *int64 `json:"user_id,omitempty"`
|
||||
APIKeyID *int64 `json:"api_key_id,omitempty"`
|
||||
AccountID *int64 `json:"account_id,omitempty"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
BillingType *int8 `json:"billing_type,omitempty"`
|
||||
}
|
||||
|
||||
type UsageCleanupTask struct {
|
||||
ID int64 `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Filters UsageCleanupFilters `json:"filters"`
|
||||
CreatedBy int64 `json:"created_by"`
|
||||
DeletedRows int64 `json:"deleted_rows"`
|
||||
ErrorMessage *string `json:"error_message,omitempty"`
|
||||
CanceledBy *int64 `json:"canceled_by,omitempty"`
|
||||
CanceledAt *time.Time `json:"canceled_at,omitempty"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
FinishedAt *time.Time `json:"finished_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AccountSummary is a minimal account info for usage log display.
|
||||
// It intentionally excludes sensitive fields like Credentials, Proxy, etc.
|
||||
type AccountSummary struct {
|
||||
@@ -277,23 +347,30 @@ type UserSubscription struct {
|
||||
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
|
||||
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUserSubscription 是管理员接口使用的订阅 DTO(包含分配信息/备注等字段)。
|
||||
// 注意:普通用户接口不得返回 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段。
|
||||
type AdminUserSubscription struct {
|
||||
UserSubscription
|
||||
|
||||
AssignedBy *int64 `json:"assigned_by"`
|
||||
AssignedAt time.Time `json:"assigned_at"`
|
||||
Notes string `json:"notes"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
||||
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
||||
}
|
||||
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []AdminUserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
|
||||
@@ -31,6 +31,7 @@ type GatewayHandler struct {
|
||||
antigravityGatewayService *service.AntigravityGatewayService
|
||||
userService *service.UserService
|
||||
billingCacheService *service.BillingCacheService
|
||||
usageService *service.UsageService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
maxAccountSwitchesGemini int
|
||||
@@ -44,6 +45,7 @@ func NewGatewayHandler(
|
||||
userService *service.UserService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
usageService *service.UsageService,
|
||||
cfg *config.Config,
|
||||
) *GatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
@@ -64,6 +66,7 @@ func NewGatewayHandler(
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
userService: userService,
|
||||
billingCacheService: billingCacheService,
|
||||
usageService: usageService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||
@@ -210,17 +213,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body)
|
||||
if interceptType != InterceptTypeNone {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockInterceptStream(c, reqModel, interceptType)
|
||||
} else {
|
||||
sendMockInterceptResponse(c, reqModel, interceptType)
|
||||
}
|
||||
return
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
@@ -359,17 +365,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body)
|
||||
if interceptType != InterceptTypeNone {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockInterceptStream(c, reqModel, interceptType)
|
||||
} else {
|
||||
sendMockInterceptResponse(c, reqModel, interceptType)
|
||||
}
|
||||
return
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
@@ -588,7 +597,7 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service
|
||||
return &cloned
|
||||
}
|
||||
|
||||
// Usage handles getting account balance for CC Switch integration
|
||||
// Usage handles getting account balance and usage statistics for CC Switch integration
|
||||
// GET /v1/usage
|
||||
func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
@@ -603,7 +612,40 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 订阅模式:返回订阅限额信息
|
||||
// Best-effort: 获取用量统计,失败不影响基础响应
|
||||
var usageData gin.H
|
||||
if h.usageService != nil {
|
||||
dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
|
||||
if err == nil && dashStats != nil {
|
||||
usageData = gin.H{
|
||||
"today": gin.H{
|
||||
"requests": dashStats.TodayRequests,
|
||||
"input_tokens": dashStats.TodayInputTokens,
|
||||
"output_tokens": dashStats.TodayOutputTokens,
|
||||
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
|
||||
"cache_read_tokens": dashStats.TodayCacheReadTokens,
|
||||
"total_tokens": dashStats.TodayTokens,
|
||||
"cost": dashStats.TodayCost,
|
||||
"actual_cost": dashStats.TodayActualCost,
|
||||
},
|
||||
"total": gin.H{
|
||||
"requests": dashStats.TotalRequests,
|
||||
"input_tokens": dashStats.TotalInputTokens,
|
||||
"output_tokens": dashStats.TotalOutputTokens,
|
||||
"cache_creation_tokens": dashStats.TotalCacheCreationTokens,
|
||||
"cache_read_tokens": dashStats.TotalCacheReadTokens,
|
||||
"total_tokens": dashStats.TotalTokens,
|
||||
"cost": dashStats.TotalCost,
|
||||
"actual_cost": dashStats.TotalActualCost,
|
||||
},
|
||||
"average_duration_ms": dashStats.AverageDurationMs,
|
||||
"rpm": dashStats.Rpm,
|
||||
"tpm": dashStats.Tpm,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 订阅模式:返回订阅限额信息 + 用量统计
|
||||
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
|
||||
subscription, ok := middleware2.GetSubscriptionFromContext(c)
|
||||
if !ok {
|
||||
@@ -612,28 +654,46 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
}
|
||||
|
||||
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
resp := gin.H{
|
||||
"isValid": true,
|
||||
"planName": apiKey.Group.Name,
|
||||
"remaining": remaining,
|
||||
"unit": "USD",
|
||||
})
|
||||
"subscription": gin.H{
|
||||
"daily_usage_usd": subscription.DailyUsageUSD,
|
||||
"weekly_usage_usd": subscription.WeeklyUsageUSD,
|
||||
"monthly_usage_usd": subscription.MonthlyUsageUSD,
|
||||
"daily_limit_usd": apiKey.Group.DailyLimitUSD,
|
||||
"weekly_limit_usd": apiKey.Group.WeeklyLimitUSD,
|
||||
"monthly_limit_usd": apiKey.Group.MonthlyLimitUSD,
|
||||
"expires_at": subscription.ExpiresAt,
|
||||
},
|
||||
}
|
||||
if usageData != nil {
|
||||
resp["usage"] = usageData
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
// 余额模式:返回钱包余额
|
||||
// 余额模式:返回钱包余额 + 用量统计
|
||||
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
resp := gin.H{
|
||||
"isValid": true,
|
||||
"planName": "钱包余额",
|
||||
"remaining": latestUser.Balance,
|
||||
"unit": "USD",
|
||||
})
|
||||
"balance": latestUser.Balance,
|
||||
}
|
||||
if usageData != nil {
|
||||
resp["usage"] = usageData
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// calculateSubscriptionRemaining 计算订阅剩余可用额度
|
||||
@@ -835,17 +895,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// isWarmupRequest 检测是否为预热请求(标题生成、Warmup等)
|
||||
func isWarmupRequest(body []byte) bool {
|
||||
// 快速检查:如果body不包含关键字,直接返回false
|
||||
// InterceptType 表示请求拦截类型
|
||||
type InterceptType int
|
||||
|
||||
const (
|
||||
InterceptTypeNone InterceptType = iota
|
||||
InterceptTypeWarmup // 预热请求(返回 "New Conversation")
|
||||
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
|
||||
)
|
||||
|
||||
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
|
||||
func detectInterceptType(body []byte) InterceptType {
|
||||
// 快速检查:如果不包含任何关键字,直接返回
|
||||
bodyStr := string(body)
|
||||
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
|
||||
return false
|
||||
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
|
||||
hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
|
||||
|
||||
if !hasSuggestionMode && !hasWarmupKeyword {
|
||||
return InterceptTypeNone
|
||||
}
|
||||
|
||||
// 解析完整请求
|
||||
// 解析请求(只解析一次)
|
||||
var req struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
@@ -856,43 +929,71 @@ func isWarmupRequest(body []byte) bool {
|
||||
} `json:"system"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return false
|
||||
return InterceptTypeNone
|
||||
}
|
||||
|
||||
// 检查 messages 中的标题提示模式
|
||||
for _, msg := range req.Messages {
|
||||
for _, content := range msg.Content {
|
||||
if content.Type == "text" {
|
||||
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
||||
content.Text == "Warmup" {
|
||||
return true
|
||||
// 检查 SUGGESTION MODE(最后一条 user 消息)
|
||||
if hasSuggestionMode && len(req.Messages) > 0 {
|
||||
lastMsg := req.Messages[len(req.Messages)-1]
|
||||
if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
|
||||
lastMsg.Content[0].Type == "text" &&
|
||||
strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
|
||||
return InterceptTypeSuggestionMode
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 Warmup 请求
|
||||
if hasWarmupKeyword {
|
||||
// 检查 messages 中的标题提示模式
|
||||
for _, msg := range req.Messages {
|
||||
for _, content := range msg.Content {
|
||||
if content.Type == "text" {
|
||||
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
||||
content.Text == "Warmup" {
|
||||
return InterceptTypeWarmup
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 检查 system 中的标题提取模式
|
||||
for _, sys := range req.System {
|
||||
if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
||||
return InterceptTypeWarmup
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 system 中的标题提取模式
|
||||
for _, system := range req.System {
|
||||
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return InterceptTypeNone
|
||||
}
|
||||
|
||||
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
// sendMockInterceptStream 发送流式 mock 响应(用于请求拦截)
|
||||
func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// 根据拦截类型决定响应内容
|
||||
var msgID string
|
||||
var outputTokens int
|
||||
var textDeltas []string
|
||||
|
||||
switch interceptType {
|
||||
case InterceptTypeSuggestionMode:
|
||||
msgID = "msg_mock_suggestion"
|
||||
outputTokens = 1
|
||||
textDeltas = []string{""} // 空内容
|
||||
default: // InterceptTypeWarmup
|
||||
msgID = "msg_mock_warmup"
|
||||
outputTokens = 2
|
||||
textDeltas = []string{"New", " Conversation"}
|
||||
}
|
||||
|
||||
// Build message_start event with proper JSON marshaling
|
||||
messageStart := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": "msg_mock_warmup",
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
@@ -907,16 +1008,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
}
|
||||
messageStartJSON, _ := json.Marshal(messageStart)
|
||||
|
||||
// Build events
|
||||
events := []string{
|
||||
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
|
||||
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
|
||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
|
||||
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
|
||||
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
|
||||
}
|
||||
|
||||
// Add text deltas
|
||||
for _, text := range textDeltas {
|
||||
delta := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]string{
|
||||
"type": "text_delta",
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
deltaJSON, _ := json.Marshal(delta)
|
||||
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
|
||||
}
|
||||
|
||||
// Add final events
|
||||
messageDelta := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
}
|
||||
messageDeltaJSON, _ := json.Marshal(messageDelta)
|
||||
|
||||
events = append(events,
|
||||
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
|
||||
`event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
|
||||
`event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
|
||||
)
|
||||
|
||||
for _, event := range events {
|
||||
_, _ = c.Writer.WriteString(event + "\n\n")
|
||||
c.Writer.Flush()
|
||||
@@ -924,18 +1055,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
}
|
||||
}
|
||||
|
||||
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupResponse(c *gin.Context, model string) {
|
||||
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
|
||||
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
|
||||
var msgID, text string
|
||||
var outputTokens int
|
||||
|
||||
switch interceptType {
|
||||
case InterceptTypeSuggestionMode:
|
||||
msgID = "msg_mock_suggestion"
|
||||
text = ""
|
||||
outputTokens = 1
|
||||
default: // InterceptTypeWarmup
|
||||
msgID = "msg_mock_warmup"
|
||||
text = "New Conversation"
|
||||
outputTokens = 2
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": "msg_mock_warmup",
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||
"content": []gin.H{{"type": "text", "text": text}},
|
||||
"stop_reason": "end_turn",
|
||||
"usage": gin.H{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 2,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
122
backend/internal/handler/gemini_cli_session_test.go
Normal file
122
backend/internal/handler/gemini_cli_session_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractGeminiCLISessionHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
privilegedUserID string
|
||||
wantEmpty bool
|
||||
wantHash string
|
||||
}{
|
||||
{
|
||||
name: "with privileged-user-id and tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: false,
|
||||
wantHash: func() string {
|
||||
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
|
||||
hash := sha256.Sum256([]byte(combined))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "without privileged-user-id but with tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
|
||||
privilegedUserID: "",
|
||||
wantEmpty: false,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "without tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 创建测试上下文
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/test", nil)
|
||||
if tt.privilegedUserID != "" {
|
||||
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
|
||||
}
|
||||
|
||||
// 调用函数
|
||||
result := extractGeminiCLISessionHash(c, []byte(tt.body))
|
||||
|
||||
// 验证结果
|
||||
if tt.wantEmpty {
|
||||
require.Empty(t, result, "expected empty session hash")
|
||||
} else {
|
||||
require.NotEmpty(t, result, "expected non-empty session hash")
|
||||
require.Equal(t, tt.wantHash, result, "session hash mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiCLITmpDirRegex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantMatch bool
|
||||
wantHash string
|
||||
}{
|
||||
{
|
||||
name: "valid tmp dir path",
|
||||
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
wantMatch: true,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "valid tmp dir path in text",
|
||||
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
|
||||
wantMatch: true,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "invalid hash length",
|
||||
input: "/Users/ianshaw/.gemini/tmp/abc123",
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
name: "no tmp dir",
|
||||
input: "Hello world",
|
||||
wantMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
|
||||
if tt.wantMatch {
|
||||
require.NotNil(t, match, "expected regex to match")
|
||||
require.Len(t, match, 2, "expected 2 capture groups")
|
||||
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
|
||||
} else {
|
||||
require.Nil(t, match, "expected regex not to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,15 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -20,6 +24,17 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
|
||||
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
|
||||
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
|
||||
|
||||
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
|
||||
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
|
||||
return true
|
||||
}
|
||||
return geminiCLITmpDirRegex.Match(body)
|
||||
}
|
||||
|
||||
// GeminiV1BetaListModels proxies:
|
||||
// GET /v1beta/models
|
||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
@@ -215,12 +230,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
// 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希)
|
||||
sessionHash := extractGeminiCLISessionHash(c, body)
|
||||
if sessionHash == "" {
|
||||
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
}
|
||||
sessionKey := sessionHash
|
||||
if sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
|
||||
// 查询粘性会话绑定的账号 ID(用于检测账号切换)
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
}
|
||||
isCLI := isGeminiCLIRequest(c, body)
|
||||
cleanedForUnknownBinding := false
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
@@ -239,6 +268,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
|
||||
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
|
||||
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
|
||||
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
|
||||
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||||
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
cleanedForUnknownBinding = true
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionBoundAccountID == 0 {
|
||||
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
|
||||
sessionBoundAccountID = account.ID
|
||||
}
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
@@ -438,3 +485,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
||||
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
||||
//
|
||||
// 会话标识生成策略:
|
||||
// 1. 从请求体中提取 tmp 目录哈希(64位十六进制)
|
||||
// 2. 从 header 中提取 privileged-user-id(UUID)
|
||||
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
|
||||
//
|
||||
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
|
||||
//
|
||||
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
|
||||
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
|
||||
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
|
||||
// 1. 从请求体中提取 tmp 目录哈希
|
||||
match := geminiCLITmpDirRegex.FindSubmatch(body)
|
||||
if len(match) < 2 {
|
||||
return "" // 没有找到 tmp 目录,不使用粘性会话
|
||||
}
|
||||
tmpDirHash := string(match[1])
|
||||
|
||||
// 2. 提取 privileged-user-id
|
||||
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
|
||||
|
||||
// 3. 组合生成最终的 session hash
|
||||
if privilegedUserID != "" {
|
||||
// 组合两个标识符:privileged-user-id + tmp 目录哈希
|
||||
combined := privilegedUserID + ":" + tmpDirHash
|
||||
hash := sha256.Sum256([]byte(combined))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// 如果没有 privileged-user-id,直接使用 tmp 目录哈希
|
||||
return tmpDirHash
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ type AdminHandlers struct {
|
||||
User *admin.UserHandler
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
@@ -33,10 +34,12 @@ type Handlers struct {
|
||||
Usage *UsageHandler
|
||||
Redeem *RedeemHandler
|
||||
Subscription *SubscriptionHandler
|
||||
Announcement *AnnouncementHandler
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time information
|
||||
|
||||
@@ -192,8 +192,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate session hash (from header for OpenAI)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
||||
// Generate session hash (header first; fallback to prompt_cache_key)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
|
||||
@@ -905,7 +905,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
||||
|
||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
|
||||
return true
|
||||
}
|
||||
if phase == "billing" || phase == "concurrency" {
|
||||
@@ -1011,5 +1011,12 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
|
||||
}
|
||||
}
|
||||
|
||||
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
||||
if settings.IgnoreInvalidApiKeyErrors {
|
||||
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -32,18 +32,24 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
181
backend/internal/handler/totp_handler.go
Normal file
181
backend/internal/handler/totp_handler.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// TotpHandler handles TOTP-related requests
|
||||
type TotpHandler struct {
|
||||
totpService *service.TotpService
|
||||
}
|
||||
|
||||
// NewTotpHandler creates a new TotpHandler
|
||||
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
|
||||
return &TotpHandler{
|
||||
totpService: totpService,
|
||||
}
|
||||
}
|
||||
|
||||
// TotpStatusResponse represents the TOTP status response
|
||||
type TotpStatusResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
|
||||
FeatureEnabled bool `json:"feature_enabled"`
|
||||
}
|
||||
|
||||
// GetStatus returns the TOTP status for the current user
|
||||
// GET /api/v1/user/totp/status
|
||||
func (h *TotpHandler) GetStatus(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := TotpStatusResponse{
|
||||
Enabled: status.Enabled,
|
||||
FeatureEnabled: status.FeatureEnabled,
|
||||
}
|
||||
|
||||
if status.EnabledAt != nil {
|
||||
ts := status.EnabledAt.Unix()
|
||||
resp.EnabledAt = &ts
|
||||
}
|
||||
|
||||
response.Success(c, resp)
|
||||
}
|
||||
|
||||
// TotpSetupRequest represents the request to initiate TOTP setup
|
||||
type TotpSetupRequest struct {
|
||||
EmailCode string `json:"email_code"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// TotpSetupResponse represents the TOTP setup response
|
||||
type TotpSetupResponse struct {
|
||||
Secret string `json:"secret"`
|
||||
QRCodeURL string `json:"qr_code_url"`
|
||||
SetupToken string `json:"setup_token"`
|
||||
Countdown int `json:"countdown"`
|
||||
}
|
||||
|
||||
// InitiateSetup starts the TOTP setup process
|
||||
// POST /api/v1/user/totp/setup
|
||||
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpSetupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// Allow empty body (optional params)
|
||||
req = TotpSetupRequest{}
|
||||
}
|
||||
|
||||
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, TotpSetupResponse{
|
||||
Secret: result.Secret,
|
||||
QRCodeURL: result.QRCodeURL,
|
||||
SetupToken: result.SetupToken,
|
||||
Countdown: result.Countdown,
|
||||
})
|
||||
}
|
||||
|
||||
// TotpEnableRequest represents the request to enable TOTP
|
||||
type TotpEnableRequest struct {
|
||||
TotpCode string `json:"totp_code" binding:"required,len=6"`
|
||||
SetupToken string `json:"setup_token" binding:"required"`
|
||||
}
|
||||
|
||||
// Enable completes the TOTP setup
|
||||
// POST /api/v1/user/totp/enable
|
||||
func (h *TotpHandler) Enable(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpEnableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// TotpDisableRequest represents the request to disable TOTP
|
||||
type TotpDisableRequest struct {
|
||||
EmailCode string `json:"email_code"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// Disable disables TOTP for the current user
|
||||
// POST /api/v1/user/totp/disable
|
||||
func (h *TotpHandler) Disable(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpDisableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// GetVerificationMethod returns the verification method for TOTP operations
|
||||
// GET /api/v1/user/totp/verification-method
|
||||
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
|
||||
method := h.totpService.GetVerificationMethod(c.Request.Context())
|
||||
response.Success(c, method)
|
||||
}
|
||||
|
||||
// SendVerifyCode sends an email verification code for TOTP operations
|
||||
// POST /api/v1/user/totp/send-code
|
||||
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
@@ -47,9 +47,6 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
userData.Notes = ""
|
||||
|
||||
response.Success(c, dto.UserFromService(userData))
|
||||
}
|
||||
|
||||
@@ -105,8 +102,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
updatedUser.Notes = ""
|
||||
|
||||
response.Success(c, dto.UserFromService(updatedUser))
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ func ProvideAdminHandlers(
|
||||
userHandler *admin.UserHandler,
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -32,6 +33,7 @@ func ProvideAdminHandlers(
|
||||
User: userHandler,
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -66,10 +68,12 @@ func ProvideHandlers(
|
||||
usageHandler *UsageHandler,
|
||||
redeemHandler *RedeemHandler,
|
||||
subscriptionHandler *SubscriptionHandler,
|
||||
announcementHandler *AnnouncementHandler,
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
totpHandler *TotpHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
@@ -78,10 +82,12 @@ func ProvideHandlers(
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Announcement: announcementHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
Totp: totpHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,8 +100,10 @@ var ProviderSet = wire.NewSet(
|
||||
NewUsageHandler,
|
||||
NewRedeemHandler,
|
||||
NewSubscriptionHandler,
|
||||
NewAnnouncementHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
NewTotpHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
@@ -103,6 +111,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewUserHandler,
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -88,6 +91,7 @@ func performRequest(router *gin.Engine) *httptest.ResponseRecorder {
|
||||
|
||||
func startRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||
t.Helper()
|
||||
ensureDockerAvailable(t)
|
||||
|
||||
redisContainer, err := tcredis.Run(ctx, redisImageTag)
|
||||
require.NoError(t, err)
|
||||
@@ -112,3 +116,43 @@ func startRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||
|
||||
return rdb
|
||||
}
|
||||
|
||||
func ensureDockerAvailable(t *testing.T) {
|
||||
t.Helper()
|
||||
if dockerAvailable() {
|
||||
return
|
||||
}
|
||||
t.Skip("Docker 未启用,跳过依赖 testcontainers 的集成测试")
|
||||
}
|
||||
|
||||
func dockerAvailable() bool {
|
||||
if os.Getenv("DOCKER_HOST") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
socketCandidates := []string{
|
||||
"/var/run/docker.sock",
|
||||
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
|
||||
filepath.Join(userHomeDir(), ".docker", "run", "docker.sock"),
|
||||
filepath.Join(userHomeDir(), ".docker", "desktop", "docker.sock"),
|
||||
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
|
||||
}
|
||||
|
||||
for _, socket := range socketCandidates {
|
||||
if socket == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(socket); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func userHomeDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return home
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ const (
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||
|
||||
// User-Agent(与 Antigravity-Manager 保持一致)
|
||||
UserAgent = "antigravity/1.11.9 windows/amd64"
|
||||
UserAgent = "antigravity/1.15.8 windows/amd64"
|
||||
|
||||
// Session 过期时间
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
@@ -369,8 +369,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
Text: block.Thinking,
|
||||
Thought: true,
|
||||
}
|
||||
// 保留原有 signature(Claude 模型需要有效的 signature)
|
||||
if block.Signature != "" {
|
||||
// signature 处理:
|
||||
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
||||
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
||||
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
|
||||
part.ThoughtSignature = block.Signature
|
||||
} else if !allowDummyThought {
|
||||
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
|
||||
@@ -409,12 +411,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
},
|
||||
}
|
||||
// tool_use 的 signature 处理:
|
||||
// - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验)
|
||||
// - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路)
|
||||
if allowDummyThought {
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
|
||||
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
||||
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
||||
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
|
||||
part.ThoughtSignature = block.Signature
|
||||
} else if allowDummyThought {
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
}
|
||||
parts = append(parts, part)
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
||||
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
|
||||
]`
|
||||
|
||||
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
|
||||
t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
|
||||
if err != nil {
|
||||
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||
}
|
||||
if parts[0].ThoughtSignature != "sig_tool_abc" {
|
||||
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
|
||||
contentNoSig := `[
|
||||
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
|
||||
]`
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
|
||||
if err != nil {
|
||||
t.Fatalf("buildParts() error = %v", err)
|
||||
}
|
||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||
}
|
||||
if parts[0].ThoughtSignature != dummyThoughtSignature {
|
||||
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
|
||||
v1Resp.Response = directResp
|
||||
v1Resp.ResponseID = directResp.ResponseID
|
||||
v1Resp.ModelVersion = directResp.ModelVersion
|
||||
} else if len(v1Resp.Response.Candidates) == 0 {
|
||||
// 第一次解析成功但 candidates 为空,说明是直接的 GeminiResponse 格式
|
||||
var directResp GeminiResponse
|
||||
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
|
||||
return nil, nil, fmt.Errorf("parse gemini response as direct: %w", err2)
|
||||
}
|
||||
v1Resp.Response = directResp
|
||||
v1Resp.ResponseID = directResp.ResponseID
|
||||
v1Resp.ModelVersion = directResp.ModelVersion
|
||||
}
|
||||
|
||||
// 使用处理器转换
|
||||
@@ -174,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
p.textBuilder += part.Text
|
||||
|
||||
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
|
||||
// 非空 text 带签名 - 特殊处理:先输出 text,再输出空 thinking 块
|
||||
if signature != "" {
|
||||
p.flushText()
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "text",
|
||||
Text: part.Text,
|
||||
})
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: signature,
|
||||
})
|
||||
} else {
|
||||
// 普通 text (无签名) - 累积到 builder
|
||||
p.textBuilder += part.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,14 +16,11 @@ type ModelsListResponse struct {
|
||||
func DefaultModels() []Model {
|
||||
methods := []string{"generateContent", "streamGenerateContent"}
|
||||
return []Model{
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,10 +12,10 @@ type Model struct {
|
||||
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
}
|
||||
|
||||
// DefaultTestModel is the default model to preselect in test flows.
|
||||
|
||||
@@ -13,20 +13,26 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Claude OAuth Constants (from CRS project)
|
||||
// Claude OAuth Constants
|
||||
const (
|
||||
// OAuth Client ID for Claude
|
||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://claude.ai/oauth/authorize"
|
||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
|
||||
TokenURL = "https://platform.claude.com/v1/oauth/token"
|
||||
RedirectURI = "https://platform.claude.com/oauth/code/callback"
|
||||
|
||||
// Scopes
|
||||
ScopeProfile = "user:profile"
|
||||
// Scopes - Browser URL (includes org:create_api_key for user authorization)
|
||||
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
|
||||
// Scopes - Internal API call (org:create_api_key not supported in API)
|
||||
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
|
||||
// Scopes - Setup token (inference only)
|
||||
ScopeInference = "user:inference"
|
||||
|
||||
// Code Verifier character set (RFC 7636 compliant)
|
||||
codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
@@ -53,7 +59,6 @@ func NewSessionStore() *SessionStore {
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
@@ -78,7 +83,6 @@ func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Check if expired
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
@@ -122,13 +126,13 @@ func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GenerateState generates a random state string for OAuth
|
||||
// GenerateState generates a random state string for OAuth (base64url encoded)
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a unique session ID
|
||||
@@ -140,13 +144,30 @@ func GenerateSessionID() (string, error) {
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
|
||||
// GenerateCodeVerifier generates a PKCE code verifier using character set method
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
const targetLen = 32
|
||||
charsetLen := len(codeVerifierCharset)
|
||||
limit := 256 - (256 % charsetLen)
|
||||
|
||||
result := make([]byte, 0, targetLen)
|
||||
randBuf := make([]byte, targetLen*2)
|
||||
|
||||
for len(result) < targetLen {
|
||||
if _, err := rand.Read(randBuf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, b := range randBuf {
|
||||
if int(b) < limit {
|
||||
result = append(result, codeVerifierCharset[int(b)%charsetLen])
|
||||
if len(result) >= targetLen {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
|
||||
return base64URLEncode(result), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
@@ -158,42 +179,31 @@ func GenerateCodeChallenge(verifier string) string {
|
||||
// base64URLEncode encodes bytes to base64url without padding
|
||||
func base64URLEncode(data []byte) string {
|
||||
encoded := base64.URLEncoding.EncodeToString(data)
|
||||
// Remove padding
|
||||
return strings.TrimRight(encoded, "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL builds the OAuth authorization URL
|
||||
// BuildAuthorizationURL builds the OAuth authorization URL with correct parameter order
|
||||
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("scope", scope)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
encodedRedirectURI := url.QueryEscape(RedirectURI)
|
||||
encodedScope := strings.ReplaceAll(url.QueryEscape(scope), "%20", "+")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
ClientID string `json:"client_id"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
State string `json:"state"`
|
||||
return fmt.Sprintf("%s?code=true&client_id=%s&response_type=code&redirect_uri=%s&scope=%s&code_challenge=%s&code_challenge_method=S256&state=%s",
|
||||
AuthorizeURL,
|
||||
ClientID,
|
||||
encodedRedirectURI,
|
||||
encodedScope,
|
||||
codeChallenge,
|
||||
state,
|
||||
)
|
||||
}
|
||||
|
||||
// TokenResponse represents the token response from OAuth provider
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// Organization and Account info from OAuth response
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Organization *OrgInfo `json:"organization,omitempty"`
|
||||
Account *AccountInfo `json:"account,omitempty"`
|
||||
}
|
||||
@@ -205,33 +215,6 @@ type OrgInfo struct {
|
||||
|
||||
// AccountInfo represents account info from OAuth response
|
||||
type AccountInfo struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
// RefreshTokenRequest represents the refresh token request
|
||||
type RefreshTokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
// BuildTokenRequest creates a token exchange request
|
||||
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
|
||||
return &TokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
ClientID: ClientID,
|
||||
Code: code,
|
||||
RedirectURI: RedirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
State: state,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRefreshTokenRequest creates a refresh token request
|
||||
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
||||
return &RefreshTokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: ClientID,
|
||||
}
|
||||
UUID string `json:"uuid"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool {
|
||||
}
|
||||
|
||||
statusCode, status := infraerrors.ToHTTP(err)
|
||||
|
||||
// Log internal errors with full details for debugging
|
||||
if statusCode >= 500 && c.Request != nil {
|
||||
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
|
||||
}
|
||||
|
||||
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
||||
return true
|
||||
}
|
||||
@@ -162,11 +169,11 @@ func ParsePagination(c *gin.Context) (page, pageSize int) {
|
||||
|
||||
// 支持 page_size 和 limit 两种参数名
|
||||
if ps := c.Query("page_size"); ps != "" {
|
||||
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
|
||||
if val, err := parseInt(ps); err == nil && val > 0 && val <= 1000 {
|
||||
pageSize = val
|
||||
}
|
||||
} else if l := c.Query("limit"); l != "" {
|
||||
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
|
||||
if val, err := parseInt(l); err == nil && val > 0 && val <= 1000 {
|
||||
pageSize = val
|
||||
}
|
||||
}
|
||||
|
||||
568
backend/internal/pkg/tlsfingerprint/dialer.go
Normal file
568
backend/internal/pkg/tlsfingerprint/dialer.go
Normal file
@@ -0,0 +1,568 @@
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
// It uses the utls library to create TLS connections that mimic Node.js/Claude Code clients.
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
utls "github.com/refraction-networking/utls"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// Profile contains TLS fingerprint configuration.
|
||||
type Profile struct {
|
||||
Name string // Profile name for identification
|
||||
CipherSuites []uint16
|
||||
Curves []uint16
|
||||
PointFormats []uint8
|
||||
EnableGREASE bool
|
||||
}
|
||||
|
||||
// Dialer creates TLS connections with custom fingerprints.
|
||||
type Dialer struct {
|
||||
profile *Profile
|
||||
baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// HTTPProxyDialer creates TLS connections through HTTP/HTTPS proxies with custom fingerprints.
|
||||
// It handles the CONNECT tunnel establishment before performing TLS handshake.
|
||||
type HTTPProxyDialer struct {
|
||||
profile *Profile
|
||||
proxyURL *url.URL
|
||||
}
|
||||
|
||||
// SOCKS5ProxyDialer creates TLS connections through SOCKS5 proxies with custom fingerprints.
|
||||
// It uses golang.org/x/net/proxy to establish the SOCKS5 tunnel.
|
||||
type SOCKS5ProxyDialer struct {
|
||||
profile *Profile
|
||||
proxyURL *url.URL
|
||||
}
|
||||
|
||||
// Default TLS fingerprint values captured from Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)
|
||||
// Captured using: tshark -i lo -f "tcp port 8443" -Y "tls.handshake.type == 1" -V
|
||||
// JA3 Hash: 1a28e69016765d92e3b381168d68922c
|
||||
//
|
||||
// Note: JA3/JA4 may have slight variations due to:
|
||||
// - Session ticket presence/absence
|
||||
// - Extension negotiation state
|
||||
var (
|
||||
// defaultCipherSuites contains all 59 cipher suites from Claude CLI
|
||||
// Order is critical for JA3 fingerprint matching
|
||||
defaultCipherSuites = []uint16{
|
||||
// TLS 1.3 cipher suites (MUST be first)
|
||||
0x1302, // TLS_AES_256_GCM_SHA384
|
||||
0x1303, // TLS_CHACHA20_POLY1305_SHA256
|
||||
0x1301, // TLS_AES_128_GCM_SHA256
|
||||
|
||||
// ECDHE + AES-GCM
|
||||
0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
|
||||
0xc02b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
|
||||
0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
|
||||
0xc02c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
|
||||
|
||||
// DHE + AES-GCM
|
||||
0x009e, // TLS_DHE_RSA_WITH_AES_128_GCM_SHA256
|
||||
|
||||
// ECDHE/DHE + AES-CBC-SHA256/384
|
||||
0xc027, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256
|
||||
0x0067, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA256
|
||||
0xc028, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384
|
||||
0x006b, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA256
|
||||
|
||||
// DHE-DSS/RSA + AES-GCM
|
||||
0x00a3, // TLS_DHE_DSS_WITH_AES_256_GCM_SHA384
|
||||
0x009f, // TLS_DHE_RSA_WITH_AES_256_GCM_SHA384
|
||||
|
||||
// ChaCha20-Poly1305
|
||||
0xcca9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
|
||||
0xcca8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
|
||||
0xccaa, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256
|
||||
|
||||
// AES-CCM (256-bit)
|
||||
0xc0af, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8
|
||||
0xc0ad, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM
|
||||
0xc0a3, // TLS_DHE_RSA_WITH_AES_256_CCM_8
|
||||
0xc09f, // TLS_DHE_RSA_WITH_AES_256_CCM
|
||||
|
||||
// ARIA (256-bit)
|
||||
0xc05d, // TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384
|
||||
0xc061, // TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384
|
||||
0xc057, // TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384
|
||||
0xc053, // TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384
|
||||
|
||||
// DHE-DSS + AES-GCM (128-bit)
|
||||
0x00a2, // TLS_DHE_DSS_WITH_AES_128_GCM_SHA256
|
||||
|
||||
// AES-CCM (128-bit)
|
||||
0xc0ae, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8
|
||||
0xc0ac, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM
|
||||
0xc0a2, // TLS_DHE_RSA_WITH_AES_128_CCM_8
|
||||
0xc09e, // TLS_DHE_RSA_WITH_AES_128_CCM
|
||||
|
||||
// ARIA (128-bit)
|
||||
0xc05c, // TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256
|
||||
0xc060, // TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256
|
||||
0xc056, // TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256
|
||||
0xc052, // TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256
|
||||
|
||||
// ECDHE/DHE + AES-CBC-SHA384/256 (more)
|
||||
0xc024, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384
|
||||
0x006a, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA256
|
||||
0xc023, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256
|
||||
0x0040, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA256
|
||||
|
||||
// ECDHE/DHE + AES-CBC-SHA (legacy)
|
||||
0xc00a, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
|
||||
0xc014, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA
|
||||
0x0039, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA
|
||||
0x0038, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA
|
||||
0xc009, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA
|
||||
0xc013, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA
|
||||
0x0033, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA
|
||||
0x0032, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA
|
||||
|
||||
// RSA + AES-GCM/CCM/ARIA (non-PFS, 256-bit)
|
||||
0x009d, // TLS_RSA_WITH_AES_256_GCM_SHA384
|
||||
0xc0a1, // TLS_RSA_WITH_AES_256_CCM_8
|
||||
0xc09d, // TLS_RSA_WITH_AES_256_CCM
|
||||
0xc051, // TLS_RSA_WITH_ARIA_256_GCM_SHA384
|
||||
|
||||
// RSA + AES-GCM/CCM/ARIA (non-PFS, 128-bit)
|
||||
0x009c, // TLS_RSA_WITH_AES_128_GCM_SHA256
|
||||
0xc0a0, // TLS_RSA_WITH_AES_128_CCM_8
|
||||
0xc09c, // TLS_RSA_WITH_AES_128_CCM
|
||||
0xc050, // TLS_RSA_WITH_ARIA_128_GCM_SHA256
|
||||
|
||||
// RSA + AES-CBC (non-PFS, legacy)
|
||||
0x003d, // TLS_RSA_WITH_AES_256_CBC_SHA256
|
||||
0x003c, // TLS_RSA_WITH_AES_128_CBC_SHA256
|
||||
0x0035, // TLS_RSA_WITH_AES_256_CBC_SHA
|
||||
0x002f, // TLS_RSA_WITH_AES_128_CBC_SHA
|
||||
|
||||
// Renegotiation indication
|
||||
0x00ff, // TLS_EMPTY_RENEGOTIATION_INFO_SCSV
|
||||
}
|
||||
|
||||
// defaultCurves contains the 10 supported groups from Claude CLI (including FFDHE)
|
||||
defaultCurves = []utls.CurveID{
|
||||
utls.X25519, // 0x001d
|
||||
utls.CurveP256, // 0x0017 (secp256r1)
|
||||
utls.CurveID(0x001e), // x448
|
||||
utls.CurveP521, // 0x0019 (secp521r1)
|
||||
utls.CurveP384, // 0x0018 (secp384r1)
|
||||
utls.CurveID(0x0100), // ffdhe2048
|
||||
utls.CurveID(0x0101), // ffdhe3072
|
||||
utls.CurveID(0x0102), // ffdhe4096
|
||||
utls.CurveID(0x0103), // ffdhe6144
|
||||
utls.CurveID(0x0104), // ffdhe8192
|
||||
}
|
||||
|
||||
// defaultPointFormats contains all 3 point formats from Claude CLI
|
||||
defaultPointFormats = []uint8{
|
||||
0, // uncompressed
|
||||
1, // ansiX962_compressed_prime
|
||||
2, // ansiX962_compressed_char2
|
||||
}
|
||||
|
||||
// defaultSignatureAlgorithms contains the 20 signature algorithms from Claude CLI
|
||||
defaultSignatureAlgorithms = []utls.SignatureScheme{
|
||||
0x0403, // ecdsa_secp256r1_sha256
|
||||
0x0503, // ecdsa_secp384r1_sha384
|
||||
0x0603, // ecdsa_secp521r1_sha512
|
||||
0x0807, // ed25519
|
||||
0x0808, // ed448
|
||||
0x0809, // rsa_pss_pss_sha256
|
||||
0x080a, // rsa_pss_pss_sha384
|
||||
0x080b, // rsa_pss_pss_sha512
|
||||
0x0804, // rsa_pss_rsae_sha256
|
||||
0x0805, // rsa_pss_rsae_sha384
|
||||
0x0806, // rsa_pss_rsae_sha512
|
||||
0x0401, // rsa_pkcs1_sha256
|
||||
0x0501, // rsa_pkcs1_sha384
|
||||
0x0601, // rsa_pkcs1_sha512
|
||||
0x0303, // ecdsa_sha224
|
||||
0x0301, // rsa_pkcs1_sha224
|
||||
0x0302, // dsa_sha224
|
||||
0x0402, // dsa_sha256
|
||||
0x0502, // dsa_sha384
|
||||
0x0602, // dsa_sha512
|
||||
}
|
||||
)
|
||||
|
||||
// NewDialer creates a new TLS fingerprint dialer.
|
||||
// baseDialer is used for TCP connection establishment (supports proxy scenarios).
|
||||
// If baseDialer is nil, direct TCP dial is used.
|
||||
func NewDialer(profile *Profile, baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *Dialer {
|
||||
if baseDialer == nil {
|
||||
baseDialer = (&net.Dialer{}).DialContext
|
||||
}
|
||||
return &Dialer{profile: profile, baseDialer: baseDialer}
|
||||
}
|
||||
|
||||
// NewHTTPProxyDialer creates a new TLS fingerprint dialer that works through HTTP/HTTPS proxies.
|
||||
// It establishes a CONNECT tunnel before performing TLS handshake with custom fingerprint.
|
||||
func NewHTTPProxyDialer(profile *Profile, proxyURL *url.URL) *HTTPProxyDialer {
|
||||
return &HTTPProxyDialer{profile: profile, proxyURL: proxyURL}
|
||||
}
|
||||
|
||||
// NewSOCKS5ProxyDialer creates a new TLS fingerprint dialer that works through SOCKS5 proxies.
|
||||
// It establishes a SOCKS5 tunnel before performing TLS handshake with custom fingerprint.
|
||||
func NewSOCKS5ProxyDialer(profile *Profile, proxyURL *url.URL) *SOCKS5ProxyDialer {
|
||||
return &SOCKS5ProxyDialer{profile: profile, proxyURL: proxyURL}
|
||||
}
|
||||
|
||||
// DialTLSContext establishes a TLS connection through SOCKS5 proxy with the configured fingerprint.
|
||||
// Flow: SOCKS5 CONNECT to target -> TLS handshake with utls on the tunnel
|
||||
func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
slog.Debug("tls_fingerprint_socks5_connecting", "proxy", d.proxyURL.Host, "target", addr)
|
||||
|
||||
// Step 1: Create SOCKS5 dialer
|
||||
var auth *proxy.Auth
|
||||
if d.proxyURL.User != nil {
|
||||
username := d.proxyURL.User.Username()
|
||||
password, _ := d.proxyURL.User.Password()
|
||||
auth = &proxy.Auth{
|
||||
User: username,
|
||||
Password: password,
|
||||
}
|
||||
}
|
||||
|
||||
// Determine proxy address
|
||||
proxyAddr := d.proxyURL.Host
|
||||
if d.proxyURL.Port() == "" {
|
||||
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "1080") // Default SOCKS5 port
|
||||
}
|
||||
|
||||
socksDialer, err := proxy.SOCKS5("tcp", proxyAddr, auth, proxy.Direct)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_dialer_failed", "error", err)
|
||||
return nil, fmt.Errorf("create SOCKS5 dialer: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Establish SOCKS5 tunnel to target
|
||||
slog.Debug("tls_fingerprint_socks5_establishing_tunnel", "target", addr)
|
||||
conn, err := socksDialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_connect_failed", "error", err)
|
||||
return nil, fmt.Errorf("SOCKS5 connect: %w", err)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_socks5_tunnel_established")
|
||||
|
||||
// Step 3: Perform TLS handshake on the tunnel with utls fingerprint
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
slog.Debug("tls_fingerprint_socks5_starting_handshake", "host", host)
|
||||
|
||||
// Build ClientHello specification from profile (Node.js/Claude CLI fingerprint)
|
||||
spec := buildClientHelloSpecFromProfile(d.profile)
|
||||
slog.Debug("tls_fingerprint_socks5_clienthello_spec",
|
||||
"cipher_suites", len(spec.CipherSuites),
|
||||
"extensions", len(spec.Extensions),
|
||||
"compression_methods", spec.CompressionMethods,
|
||||
"tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
|
||||
"tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
|
||||
|
||||
if d.profile != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
|
||||
}
|
||||
|
||||
// Create uTLS connection on the tunnel
|
||||
tlsConn := utls.UClient(conn, &utls.Config{
|
||||
ServerName: host,
|
||||
}, utls.HelloCustom)
|
||||
|
||||
if err := tlsConn.ApplyPreset(spec); err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_apply_preset_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("apply TLS preset: %w", err)
|
||||
}
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
}
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_socks5_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// DialTLSContext establishes a TLS connection through HTTP proxy with the configured fingerprint.
|
||||
// Flow: TCP connect to proxy -> CONNECT tunnel -> TLS handshake with utls
|
||||
func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
slog.Debug("tls_fingerprint_http_proxy_connecting", "proxy", d.proxyURL.Host, "target", addr)
|
||||
|
||||
// Step 1: TCP connect to proxy server
|
||||
var proxyAddr string
|
||||
if d.proxyURL.Port() != "" {
|
||||
proxyAddr = d.proxyURL.Host
|
||||
} else {
|
||||
// Default ports
|
||||
if d.proxyURL.Scheme == "https" {
|
||||
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "443")
|
||||
} else {
|
||||
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "80")
|
||||
}
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", proxyAddr)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_http_proxy_connect_failed", "error", err)
|
||||
return nil, fmt.Errorf("connect to proxy: %w", err)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_http_proxy_connected", "proxy_addr", proxyAddr)
|
||||
|
||||
// Step 2: Send CONNECT request to establish tunnel
|
||||
req := &http.Request{
|
||||
Method: "CONNECT",
|
||||
URL: &url.URL{Opaque: addr},
|
||||
Host: addr,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
|
||||
// Add proxy authentication if present
|
||||
if d.proxyURL.User != nil {
|
||||
username := d.proxyURL.User.Username()
|
||||
password, _ := d.proxyURL.User.Password()
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
|
||||
req.Header.Set("Proxy-Authorization", "Basic "+auth)
|
||||
}
|
||||
|
||||
slog.Debug("tls_fingerprint_http_proxy_sending_connect", "target", addr)
|
||||
if err := req.Write(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
slog.Debug("tls_fingerprint_http_proxy_write_failed", "error", err)
|
||||
return nil, fmt.Errorf("write CONNECT request: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Read CONNECT response
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(br, req)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
slog.Debug("tls_fingerprint_http_proxy_read_response_failed", "error", err)
|
||||
return nil, fmt.Errorf("read CONNECT response: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = conn.Close()
|
||||
slog.Debug("tls_fingerprint_http_proxy_connect_failed_status", "status_code", resp.StatusCode, "status", resp.Status)
|
||||
return nil, fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_http_proxy_tunnel_established")
|
||||
|
||||
// Step 4: Perform TLS handshake on the tunnel with utls fingerprint
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
slog.Debug("tls_fingerprint_http_proxy_starting_handshake", "host", host)
|
||||
|
||||
// Build ClientHello specification (reuse the shared method)
|
||||
spec := buildClientHelloSpecFromProfile(d.profile)
|
||||
slog.Debug("tls_fingerprint_http_proxy_clienthello_spec",
|
||||
"cipher_suites", len(spec.CipherSuites),
|
||||
"extensions", len(spec.Extensions))
|
||||
|
||||
if d.profile != nil {
|
||||
slog.Debug("tls_fingerprint_http_proxy_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
|
||||
}
|
||||
|
||||
// Create uTLS connection on the tunnel
|
||||
// Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
|
||||
tlsConn := utls.UClient(conn, &utls.Config{
|
||||
ServerName: host,
|
||||
}, utls.HelloCustom)
|
||||
|
||||
if err := tlsConn.ApplyPreset(spec); err != nil {
|
||||
slog.Debug("tls_fingerprint_http_proxy_apply_preset_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("apply TLS preset: %w", err)
|
||||
}
|
||||
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
slog.Debug("tls_fingerprint_http_proxy_handshake_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
}
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_http_proxy_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// DialTLSContext establishes a TLS connection with the configured fingerprint.
|
||||
// This method is designed to be used as http.Transport.DialTLSContext.
|
||||
func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// Establish TCP connection using base dialer (supports proxy)
|
||||
slog.Debug("tls_fingerprint_dialing_tcp", "addr", addr)
|
||||
conn, err := d.baseDialer(ctx, network, addr)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_tcp_dial_failed", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
slog.Debug("tls_fingerprint_tcp_connected", "addr", addr)
|
||||
|
||||
// Extract hostname for SNI
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
slog.Debug("tls_fingerprint_sni_hostname", "host", host)
|
||||
|
||||
// Build ClientHello specification
|
||||
spec := d.buildClientHelloSpec()
|
||||
slog.Debug("tls_fingerprint_clienthello_spec",
|
||||
"cipher_suites", len(spec.CipherSuites),
|
||||
"extensions", len(spec.Extensions))
|
||||
|
||||
// Log profile info
|
||||
if d.profile != nil {
|
||||
slog.Debug("tls_fingerprint_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
|
||||
} else {
|
||||
slog.Debug("tls_fingerprint_using_default_profile")
|
||||
}
|
||||
|
||||
// Create uTLS connection
|
||||
// Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
|
||||
tlsConn := utls.UClient(conn, &utls.Config{
|
||||
ServerName: host,
|
||||
}, utls.HelloCustom)
|
||||
|
||||
// Apply fingerprint
|
||||
if err := tlsConn.ApplyPreset(spec); err != nil {
|
||||
slog.Debug("tls_fingerprint_apply_preset_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
slog.Debug("tls_fingerprint_preset_applied")
|
||||
|
||||
// Perform TLS handshake
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
slog.Debug("tls_fingerprint_handshake_failed",
|
||||
"error", err,
|
||||
"local_addr", conn.LocalAddr(),
|
||||
"remote_addr", conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
}
|
||||
|
||||
// Log successful handshake details
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// buildClientHelloSpec constructs the ClientHello specification based on the profile.
|
||||
func (d *Dialer) buildClientHelloSpec() *utls.ClientHelloSpec {
|
||||
return buildClientHelloSpecFromProfile(d.profile)
|
||||
}
|
||||
|
||||
// toUTLSCurves converts uint16 slice to utls.CurveID slice.
|
||||
func toUTLSCurves(curves []uint16) []utls.CurveID {
|
||||
result := make([]utls.CurveID, len(curves))
|
||||
for i, c := range curves {
|
||||
result[i] = utls.CurveID(c)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// buildClientHelloSpecFromProfile constructs ClientHelloSpec from a Profile.
|
||||
// This is a standalone function that can be used by both Dialer and HTTPProxyDialer.
|
||||
func buildClientHelloSpecFromProfile(profile *Profile) *utls.ClientHelloSpec {
|
||||
// Get cipher suites
|
||||
var cipherSuites []uint16
|
||||
if profile != nil && len(profile.CipherSuites) > 0 {
|
||||
cipherSuites = profile.CipherSuites
|
||||
} else {
|
||||
cipherSuites = defaultCipherSuites
|
||||
}
|
||||
|
||||
// Get curves
|
||||
var curves []utls.CurveID
|
||||
if profile != nil && len(profile.Curves) > 0 {
|
||||
curves = toUTLSCurves(profile.Curves)
|
||||
} else {
|
||||
curves = defaultCurves
|
||||
}
|
||||
|
||||
// Get point formats
|
||||
var pointFormats []uint8
|
||||
if profile != nil && len(profile.PointFormats) > 0 {
|
||||
pointFormats = profile.PointFormats
|
||||
} else {
|
||||
pointFormats = defaultPointFormats
|
||||
}
|
||||
|
||||
// Check if GREASE is enabled
|
||||
enableGREASE := profile != nil && profile.EnableGREASE
|
||||
|
||||
extensions := make([]utls.TLSExtension, 0, 16)
|
||||
|
||||
if enableGREASE {
|
||||
extensions = append(extensions, &utls.UtlsGREASEExtension{})
|
||||
}
|
||||
|
||||
// SNI extension - MUST be explicitly added for HelloCustom mode
|
||||
// utls will populate the server name from Config.ServerName
|
||||
extensions = append(extensions, &utls.SNIExtension{})
|
||||
|
||||
// Claude CLI extension order (captured from tshark):
|
||||
// server_name(0), ec_point_formats(11), supported_groups(10), session_ticket(35),
|
||||
// alpn(16), encrypt_then_mac(22), extended_master_secret(23),
|
||||
// signature_algorithms(13), supported_versions(43),
|
||||
// psk_key_exchange_modes(45), key_share(51)
|
||||
extensions = append(extensions,
|
||||
&utls.SupportedPointsExtension{SupportedPoints: pointFormats},
|
||||
&utls.SupportedCurvesExtension{Curves: curves},
|
||||
&utls.SessionTicketExtension{},
|
||||
&utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}},
|
||||
&utls.GenericExtension{Id: 22},
|
||||
&utls.ExtendedMasterSecretExtension{},
|
||||
&utls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: defaultSignatureAlgorithms},
|
||||
&utls.SupportedVersionsExtension{Versions: []uint16{
|
||||
utls.VersionTLS13,
|
||||
utls.VersionTLS12,
|
||||
}},
|
||||
&utls.PSKKeyExchangeModesExtension{Modes: []uint8{utls.PskModeDHE}},
|
||||
&utls.KeyShareExtension{KeyShares: []utls.KeyShare{
|
||||
{Group: utls.X25519},
|
||||
}},
|
||||
)
|
||||
|
||||
if enableGREASE {
|
||||
extensions = append(extensions, &utls.UtlsGREASEExtension{})
|
||||
}
|
||||
|
||||
return &utls.ClientHelloSpec{
|
||||
CipherSuites: cipherSuites,
|
||||
CompressionMethods: []uint8{0}, // null compression only (standard)
|
||||
Extensions: extensions,
|
||||
TLSVersMax: utls.VersionTLS13,
|
||||
TLSVersMin: utls.VersionTLS10,
|
||||
}
|
||||
}
|
||||
278
backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
Normal file
278
backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
Normal file
@@ -0,0 +1,278 @@
|
||||
//go:build integration
|
||||
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
//
|
||||
// Integration tests for verifying TLS fingerprint correctness.
|
||||
// These tests make actual network requests to external services and should be run manually.
|
||||
//
|
||||
// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// skipIfExternalServiceUnavailable checks if the external service is available.
|
||||
// If not, it skips the test instead of failing.
|
||||
func skipIfExternalServiceUnavailable(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
// Check for common network/TLS errors that indicate external service issues
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "certificate has expired") ||
|
||||
strings.Contains(errStr, "certificate is not yet valid") ||
|
||||
strings.Contains(errStr, "connection refused") ||
|
||||
strings.Contains(errStr, "no such host") ||
|
||||
strings.Contains(errStr, "network is unreachable") ||
|
||||
strings.Contains(errStr, "timeout") {
|
||||
t.Skipf("skipping test: external service unavailable: %v", err)
|
||||
}
|
||||
t.Fatalf("failed to get fingerprint: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
|
||||
// This test uses tls.peet.ws to verify the fingerprint.
|
||||
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
|
||||
func TestJA3Fingerprint(t *testing.T) {
|
||||
// Skip if network is unavailable or if running in short mode
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
profile := &Profile{
|
||||
Name: "Claude CLI Test",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
dialer := NewDialer(profile, nil)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Use tls.peet.ws fingerprint detection API
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
skipIfExternalServiceUnavailable(t, err)
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
}
|
||||
|
||||
// Log all fingerprint information
|
||||
t.Logf("JA3: %s", fpResp.TLS.JA3)
|
||||
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
|
||||
t.Logf("JA4: %s", fpResp.TLS.JA4)
|
||||
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
|
||||
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
|
||||
|
||||
// Verify JA3 hash matches expected value
|
||||
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
|
||||
if fpResp.TLS.JA3Hash == expectedJA3Hash {
|
||||
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
|
||||
} else {
|
||||
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
|
||||
}
|
||||
|
||||
// Verify JA4 fingerprint
|
||||
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
|
||||
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
|
||||
// The suffix _a33745022dd6_1f22a2ca17c4 should match
|
||||
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
|
||||
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
|
||||
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
|
||||
}
|
||||
|
||||
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
|
||||
// d = domain (SNI present), i = IP (no SNI)
|
||||
// Since we connect to tls.peet.ws (domain), we expect 'd'
|
||||
expectedJA4Prefix := "t13d5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
|
||||
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
|
||||
} else {
|
||||
// Also accept 'i' variant for IP connections
|
||||
altPrefix := "t13i5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
|
||||
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
|
||||
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
|
||||
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
|
||||
} else {
|
||||
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
|
||||
}
|
||||
|
||||
// Verify extension list (should be 11 extensions including SNI)
|
||||
// Expected: 0-11-10-35-16-22-23-13-43-45-51
|
||||
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
|
||||
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
|
||||
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
|
||||
} else {
|
||||
t.Logf("Warning: JA3 extension list may differ")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileExpectation defines expected fingerprint values for a profile.
|
||||
type TestProfileExpectation struct {
|
||||
Profile *Profile
|
||||
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
|
||||
ExpectedJA4 string // Expected full JA4 (empty = don't check)
|
||||
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
|
||||
}
|
||||
|
||||
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
|
||||
// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
|
||||
func TestAllProfiles(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Define all profiles to test with their expected fingerprints
|
||||
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
|
||||
profiles := []TestProfileExpectation{
|
||||
{
|
||||
// Linux x64 Node.js v22.17.1
|
||||
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
|
||||
Profile: &Profile{
|
||||
Name: "linux_x64_node_v22171",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part
|
||||
},
|
||||
{
|
||||
// MacOS arm64 Node.js v22.18.0
|
||||
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
|
||||
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
|
||||
Profile: &Profile{
|
||||
Name: "macos_arm64_node_v22180",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range profiles {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.Profile.Name, func(t *testing.T) {
|
||||
fp := fetchFingerprint(t, tc.Profile)
|
||||
if fp == nil {
|
||||
return // fetchFingerprint already called t.Fatal
|
||||
}
|
||||
|
||||
t.Logf("Profile: %s", tc.Profile.Name)
|
||||
t.Logf(" JA3: %s", fp.JA3)
|
||||
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
|
||||
t.Logf(" JA4: %s", fp.JA4)
|
||||
t.Logf(" PeetPrint: %s", fp.PeetPrint)
|
||||
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
|
||||
|
||||
// Verify expectations
|
||||
if tc.ExpectedJA3 != "" {
|
||||
if fp.JA3Hash == tc.ExpectedJA3 {
|
||||
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.ExpectedJA4 != "" {
|
||||
if fp.JA4 == tc.ExpectedJA4 {
|
||||
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
|
||||
}
|
||||
}
|
||||
|
||||
// Check JA4 cipher hash (stable middle part)
|
||||
// JA4 format: prefix_cipherHash_extHash
|
||||
if tc.JA4CipherHash != "" {
|
||||
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
|
||||
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
|
||||
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
|
||||
t.Helper()
|
||||
|
||||
dialer := NewDialer(profile, nil)
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
return nil
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
skipIfExternalServiceUnavailable(t, err)
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &fpResp.TLS
|
||||
}
|
||||
160
backend/internal/pkg/tlsfingerprint/dialer_test.go
Normal file
160
backend/internal/pkg/tlsfingerprint/dialer_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
//
|
||||
// Unit tests for TLS fingerprint dialer.
|
||||
// Integration tests that require external network are in dialer_integration_test.go
|
||||
// and require the 'integration' build tag.
|
||||
//
|
||||
// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
|
||||
// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// FingerprintResponse represents the response from tls.peet.ws/api/all.
|
||||
type FingerprintResponse struct {
|
||||
IP string `json:"ip"`
|
||||
TLS TLSInfo `json:"tls"`
|
||||
HTTP2 any `json:"http2"`
|
||||
}
|
||||
|
||||
// TLSInfo contains TLS fingerprint details.
|
||||
type TLSInfo struct {
|
||||
JA3 string `json:"ja3"`
|
||||
JA3Hash string `json:"ja3_hash"`
|
||||
JA4 string `json:"ja4"`
|
||||
PeetPrint string `json:"peetprint"`
|
||||
PeetPrintHash string `json:"peetprint_hash"`
|
||||
ClientRandom string `json:"client_random"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// TestDialerWithProfile tests that different profiles produce different fingerprints.
|
||||
func TestDialerWithProfile(t *testing.T) {
|
||||
// Create two dialers with different profiles
|
||||
profile1 := &Profile{
|
||||
Name: "Profile 1 - No GREASE",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
profile2 := &Profile{
|
||||
Name: "Profile 2 - With GREASE",
|
||||
EnableGREASE: true,
|
||||
}
|
||||
|
||||
dialer1 := NewDialer(profile1, nil)
|
||||
dialer2 := NewDialer(profile2, nil)
|
||||
|
||||
// Build specs and compare
|
||||
// Note: We can't directly compare JA3 without making network requests
|
||||
// but we can verify the specs are different
|
||||
spec1 := dialer1.buildClientHelloSpec()
|
||||
spec2 := dialer2.buildClientHelloSpec()
|
||||
|
||||
// Profile with GREASE should have more extensions
|
||||
if len(spec2.Extensions) <= len(spec1.Extensions) {
|
||||
t.Error("expected GREASE profile to have more extensions")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHTTPProxyDialerBasic tests HTTP proxy dialer creation.
|
||||
// Note: This is a unit test - actual proxy testing requires a proxy server.
|
||||
func TestHTTPProxyDialerBasic(t *testing.T) {
|
||||
profile := &Profile{
|
||||
Name: "Test Profile",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
|
||||
// Test that dialer is created without panic
|
||||
proxyURL := mustParseURL("http://proxy.example.com:8080")
|
||||
dialer := NewHTTPProxyDialer(profile, proxyURL)
|
||||
|
||||
if dialer == nil {
|
||||
t.Fatal("expected dialer to be created")
|
||||
}
|
||||
if dialer.profile != profile {
|
||||
t.Error("expected profile to be set")
|
||||
}
|
||||
if dialer.proxyURL != proxyURL {
|
||||
t.Error("expected proxyURL to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSOCKS5ProxyDialerBasic tests SOCKS5 proxy dialer creation.
|
||||
// Note: This is a unit test - actual proxy testing requires a proxy server.
|
||||
func TestSOCKS5ProxyDialerBasic(t *testing.T) {
|
||||
profile := &Profile{
|
||||
Name: "Test Profile",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
|
||||
// Test that dialer is created without panic
|
||||
proxyURL := mustParseURL("socks5://proxy.example.com:1080")
|
||||
dialer := NewSOCKS5ProxyDialer(profile, proxyURL)
|
||||
|
||||
if dialer == nil {
|
||||
t.Fatal("expected dialer to be created")
|
||||
}
|
||||
if dialer.profile != profile {
|
||||
t.Error("expected profile to be set")
|
||||
}
|
||||
if dialer.proxyURL != proxyURL {
|
||||
t.Error("expected proxyURL to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildClientHelloSpec tests ClientHello spec construction.
|
||||
func TestBuildClientHelloSpec(t *testing.T) {
|
||||
// Test with nil profile (should use defaults)
|
||||
spec := buildClientHelloSpecFromProfile(nil)
|
||||
|
||||
if len(spec.CipherSuites) == 0 {
|
||||
t.Error("expected cipher suites to be set")
|
||||
}
|
||||
if len(spec.Extensions) == 0 {
|
||||
t.Error("expected extensions to be set")
|
||||
}
|
||||
|
||||
// Verify default cipher suites are used
|
||||
if len(spec.CipherSuites) != len(defaultCipherSuites) {
|
||||
t.Errorf("expected %d cipher suites, got %d", len(defaultCipherSuites), len(spec.CipherSuites))
|
||||
}
|
||||
|
||||
// Test with custom profile
|
||||
customProfile := &Profile{
|
||||
Name: "Custom",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{0x1301, 0x1302},
|
||||
}
|
||||
spec = buildClientHelloSpecFromProfile(customProfile)
|
||||
|
||||
if len(spec.CipherSuites) != 2 {
|
||||
t.Errorf("expected 2 cipher suites, got %d", len(spec.CipherSuites))
|
||||
}
|
||||
}
|
||||
|
||||
// TestToUTLSCurves tests curve ID conversion.
|
||||
func TestToUTLSCurves(t *testing.T) {
|
||||
input := []uint16{0x001d, 0x0017, 0x0018}
|
||||
result := toUTLSCurves(input)
|
||||
|
||||
if len(result) != len(input) {
|
||||
t.Errorf("expected %d curves, got %d", len(input), len(result))
|
||||
}
|
||||
|
||||
for i, curve := range result {
|
||||
if uint16(curve) != input[i] {
|
||||
t.Errorf("curve %d: expected 0x%04x, got 0x%04x", i, input[i], uint16(curve))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to parse URL without error handling.
|
||||
func mustParseURL(rawURL string) *url.URL {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
171
backend/internal/pkg/tlsfingerprint/registry.go
Normal file
171
backend/internal/pkg/tlsfingerprint/registry.go
Normal file
@@ -0,0 +1,171 @@
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// DefaultProfileName is the name of the built-in Claude CLI profile.
|
||||
const DefaultProfileName = "claude_cli_v2"
|
||||
|
||||
// Registry manages TLS fingerprint profiles.
|
||||
// It holds a collection of profiles that can be used for TLS fingerprint simulation.
|
||||
// Profiles are selected based on account ID using modulo operation.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
profiles map[string]*Profile
|
||||
profileNames []string // Sorted list of profile names for deterministic selection
|
||||
}
|
||||
|
||||
// NewRegistry creates a new TLS fingerprint profile registry.
|
||||
// It initializes with the built-in default profile.
|
||||
func NewRegistry() *Registry {
|
||||
r := &Registry{
|
||||
profiles: make(map[string]*Profile),
|
||||
profileNames: make([]string, 0),
|
||||
}
|
||||
|
||||
// Register the built-in default profile
|
||||
r.registerBuiltinProfile()
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// NewRegistryFromConfig creates a new registry and loads profiles from config.
|
||||
// If the config has custom profiles defined, they will be merged with the built-in default.
|
||||
func NewRegistryFromConfig(cfg *config.TLSFingerprintConfig) *Registry {
|
||||
r := NewRegistry()
|
||||
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
slog.Debug("tls_registry_disabled", "reason", "disabled or no config")
|
||||
return r
|
||||
}
|
||||
|
||||
// Load custom profiles from config
|
||||
for name, profileCfg := range cfg.Profiles {
|
||||
profile := &Profile{
|
||||
Name: profileCfg.Name,
|
||||
EnableGREASE: profileCfg.EnableGREASE,
|
||||
CipherSuites: profileCfg.CipherSuites,
|
||||
Curves: profileCfg.Curves,
|
||||
PointFormats: profileCfg.PointFormats,
|
||||
}
|
||||
|
||||
// If the profile has empty values, they will use defaults in dialer
|
||||
r.RegisterProfile(name, profile)
|
||||
slog.Debug("tls_registry_loaded_profile", "key", name, "name", profileCfg.Name)
|
||||
}
|
||||
|
||||
slog.Debug("tls_registry_initialized", "profile_count", len(r.profileNames), "profiles", r.profileNames)
|
||||
return r
|
||||
}
|
||||
|
||||
// registerBuiltinProfile adds the default Claude CLI profile to the registry.
|
||||
func (r *Registry) registerBuiltinProfile() {
|
||||
defaultProfile := &Profile{
|
||||
Name: "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)",
|
||||
EnableGREASE: false, // Node.js does not use GREASE
|
||||
// Empty slices will cause dialer to use built-in defaults
|
||||
CipherSuites: nil,
|
||||
Curves: nil,
|
||||
PointFormats: nil,
|
||||
}
|
||||
r.RegisterProfile(DefaultProfileName, defaultProfile)
|
||||
}
|
||||
|
||||
// RegisterProfile adds or updates a profile in the registry.
|
||||
func (r *Registry) RegisterProfile(name string, profile *Profile) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Check if this is a new profile
|
||||
_, exists := r.profiles[name]
|
||||
r.profiles[name] = profile
|
||||
|
||||
if !exists {
|
||||
r.profileNames = append(r.profileNames, name)
|
||||
// Keep names sorted for deterministic selection
|
||||
sort.Strings(r.profileNames)
|
||||
}
|
||||
}
|
||||
|
||||
// GetProfile returns a profile by name.
|
||||
// Returns nil if the profile does not exist.
|
||||
func (r *Registry) GetProfile(name string) *Profile {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.profiles[name]
|
||||
}
|
||||
|
||||
// GetDefaultProfile returns the built-in default profile.
|
||||
func (r *Registry) GetDefaultProfile() *Profile {
|
||||
return r.GetProfile(DefaultProfileName)
|
||||
}
|
||||
|
||||
// GetProfileByAccountID returns a profile for the given account ID.
|
||||
// The profile is selected using: profileNames[accountID % len(profiles)]
|
||||
// This ensures deterministic profile assignment for each account.
|
||||
func (r *Registry) GetProfileByAccountID(accountID int64) *Profile {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if len(r.profileNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use modulo to select profile index
|
||||
// Use absolute value to handle negative IDs (though unlikely)
|
||||
idx := accountID
|
||||
if idx < 0 {
|
||||
idx = -idx
|
||||
}
|
||||
selectedIndex := int(idx % int64(len(r.profileNames)))
|
||||
selectedName := r.profileNames[selectedIndex]
|
||||
|
||||
return r.profiles[selectedName]
|
||||
}
|
||||
|
||||
// ProfileCount returns the number of registered profiles.
|
||||
func (r *Registry) ProfileCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return len(r.profiles)
|
||||
}
|
||||
|
||||
// ProfileNames returns a sorted list of all registered profile names.
|
||||
func (r *Registry) ProfileNames() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// Return a copy to prevent modification
|
||||
names := make([]string, len(r.profileNames))
|
||||
copy(names, r.profileNames)
|
||||
return names
|
||||
}
|
||||
|
||||
// Global registry instance for convenience
|
||||
var globalRegistry *Registry
|
||||
var globalRegistryOnce sync.Once
|
||||
|
||||
// GlobalRegistry returns the global TLS fingerprint registry.
|
||||
// The registry is lazily initialized with the default profile.
|
||||
func GlobalRegistry() *Registry {
|
||||
globalRegistryOnce.Do(func() {
|
||||
globalRegistry = NewRegistry()
|
||||
})
|
||||
return globalRegistry
|
||||
}
|
||||
|
||||
// InitGlobalRegistry initializes the global registry with configuration.
|
||||
// This should be called during application startup.
|
||||
// It is safe to call multiple times; subsequent calls will update the registry.
|
||||
func InitGlobalRegistry(cfg *config.TLSFingerprintConfig) *Registry {
|
||||
globalRegistryOnce.Do(func() {
|
||||
globalRegistry = NewRegistryFromConfig(cfg)
|
||||
})
|
||||
return globalRegistry
|
||||
}
|
||||
243
backend/internal/pkg/tlsfingerprint/registry_test.go
Normal file
243
backend/internal/pkg/tlsfingerprint/registry_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
func TestNewRegistry(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// Should have exactly one profile (the default)
|
||||
if r.ProfileCount() != 1 {
|
||||
t.Errorf("expected 1 profile, got %d", r.ProfileCount())
|
||||
}
|
||||
|
||||
// Should have the default profile
|
||||
profile := r.GetDefaultProfile()
|
||||
if profile == nil {
|
||||
t.Error("expected default profile to exist")
|
||||
}
|
||||
|
||||
// Default profile name should be in the list
|
||||
names := r.ProfileNames()
|
||||
if len(names) != 1 || names[0] != DefaultProfileName {
|
||||
t.Errorf("expected profile names to be [%s], got %v", DefaultProfileName, names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProfile(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// Register a new profile
|
||||
customProfile := &Profile{
|
||||
Name: "Custom Profile",
|
||||
EnableGREASE: true,
|
||||
}
|
||||
r.RegisterProfile("custom", customProfile)
|
||||
|
||||
// Should now have 2 profiles
|
||||
if r.ProfileCount() != 2 {
|
||||
t.Errorf("expected 2 profiles, got %d", r.ProfileCount())
|
||||
}
|
||||
|
||||
// Should be able to retrieve the custom profile
|
||||
retrieved := r.GetProfile("custom")
|
||||
if retrieved == nil {
|
||||
t.Fatal("expected custom profile to exist")
|
||||
}
|
||||
if retrieved.Name != "Custom Profile" {
|
||||
t.Errorf("expected profile name 'Custom Profile', got '%s'", retrieved.Name)
|
||||
}
|
||||
if !retrieved.EnableGREASE {
|
||||
t.Error("expected EnableGREASE to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProfile(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// Get existing profile
|
||||
profile := r.GetProfile(DefaultProfileName)
|
||||
if profile == nil {
|
||||
t.Error("expected default profile to exist")
|
||||
}
|
||||
|
||||
// Get non-existing profile
|
||||
nonExistent := r.GetProfile("nonexistent")
|
||||
if nonExistent != nil {
|
||||
t.Error("expected nil for non-existent profile")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProfileByAccountID(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// With only default profile, all account IDs should return the same profile
|
||||
for i := int64(0); i < 10; i++ {
|
||||
profile := r.GetProfileByAccountID(i)
|
||||
if profile == nil {
|
||||
t.Errorf("expected profile for account %d, got nil", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Add more profiles
|
||||
r.RegisterProfile("profile_a", &Profile{Name: "Profile A"})
|
||||
r.RegisterProfile("profile_b", &Profile{Name: "Profile B"})
|
||||
|
||||
// Now we have 3 profiles: claude_cli_v2, profile_a, profile_b
|
||||
// Names are sorted, so order is: claude_cli_v2, profile_a, profile_b
|
||||
expectedOrder := []string{DefaultProfileName, "profile_a", "profile_b"}
|
||||
names := r.ProfileNames()
|
||||
for i, name := range expectedOrder {
|
||||
if names[i] != name {
|
||||
t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Test modulo selection
|
||||
// Account ID 0 % 3 = 0 -> claude_cli_v2
|
||||
// Account ID 1 % 3 = 1 -> profile_a
|
||||
// Account ID 2 % 3 = 2 -> profile_b
|
||||
// Account ID 3 % 3 = 0 -> claude_cli_v2
|
||||
testCases := []struct {
|
||||
accountID int64
|
||||
expectedName string
|
||||
}{
|
||||
{0, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"},
|
||||
{1, "Profile A"},
|
||||
{2, "Profile B"},
|
||||
{3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"},
|
||||
{4, "Profile A"},
|
||||
{5, "Profile B"},
|
||||
{100, "Profile A"}, // 100 % 3 = 1
|
||||
{-1, "Profile A"}, // |-1| % 3 = 1
|
||||
{-3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, // |-3| % 3 = 0
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
profile := r.GetProfileByAccountID(tc.accountID)
|
||||
if profile == nil {
|
||||
t.Errorf("expected profile for account %d, got nil", tc.accountID)
|
||||
continue
|
||||
}
|
||||
if profile.Name != tc.expectedName {
|
||||
t.Errorf("account %d: expected profile name '%s', got '%s'", tc.accountID, tc.expectedName, profile.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRegistryFromConfig(t *testing.T) {
|
||||
// Test with nil config
|
||||
r := NewRegistryFromConfig(nil)
|
||||
if r.ProfileCount() != 1 {
|
||||
t.Errorf("expected 1 profile with nil config, got %d", r.ProfileCount())
|
||||
}
|
||||
|
||||
// Test with disabled config
|
||||
disabledCfg := &config.TLSFingerprintConfig{
|
||||
Enabled: false,
|
||||
}
|
||||
r = NewRegistryFromConfig(disabledCfg)
|
||||
if r.ProfileCount() != 1 {
|
||||
t.Errorf("expected 1 profile with disabled config, got %d", r.ProfileCount())
|
||||
}
|
||||
|
||||
// Test with enabled config and custom profiles
|
||||
enabledCfg := &config.TLSFingerprintConfig{
|
||||
Enabled: true,
|
||||
Profiles: map[string]config.TLSProfileConfig{
|
||||
"custom1": {
|
||||
Name: "Custom Profile 1",
|
||||
EnableGREASE: true,
|
||||
},
|
||||
"custom2": {
|
||||
Name: "Custom Profile 2",
|
||||
EnableGREASE: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
r = NewRegistryFromConfig(enabledCfg)
|
||||
|
||||
// Should have 3 profiles: default + 2 custom
|
||||
if r.ProfileCount() != 3 {
|
||||
t.Errorf("expected 3 profiles, got %d", r.ProfileCount())
|
||||
}
|
||||
|
||||
// Check custom profiles exist
|
||||
custom1 := r.GetProfile("custom1")
|
||||
if custom1 == nil || custom1.Name != "Custom Profile 1" {
|
||||
t.Error("expected custom1 profile to exist with correct name")
|
||||
}
|
||||
custom2 := r.GetProfile("custom2")
|
||||
if custom2 == nil || custom2.Name != "Custom Profile 2" {
|
||||
t.Error("expected custom2 profile to exist with correct name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProfileNames(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// Add profiles in non-alphabetical order
|
||||
r.RegisterProfile("zebra", &Profile{Name: "Zebra"})
|
||||
r.RegisterProfile("alpha", &Profile{Name: "Alpha"})
|
||||
r.RegisterProfile("beta", &Profile{Name: "Beta"})
|
||||
|
||||
names := r.ProfileNames()
|
||||
|
||||
// Should be sorted alphabetically
|
||||
expected := []string{"alpha", "beta", DefaultProfileName, "zebra"}
|
||||
if len(names) != len(expected) {
|
||||
t.Errorf("expected %d names, got %d", len(expected), len(names))
|
||||
}
|
||||
for i, name := range expected {
|
||||
if names[i] != name {
|
||||
t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Test that returned slice is a copy (modifying it shouldn't affect registry)
|
||||
names[0] = "modified"
|
||||
originalNames := r.ProfileNames()
|
||||
if originalNames[0] == "modified" {
|
||||
t.Error("modifying returned slice should not affect registry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// Run concurrent reads and writes
|
||||
done := make(chan bool)
|
||||
|
||||
// Writers
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
for j := 0; j < 100; j++ {
|
||||
r.RegisterProfile("concurrent"+string(rune('0'+id)), &Profile{Name: "Concurrent"})
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Readers
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
for j := 0; j < 100; j++ {
|
||||
_ = r.ProfileCount()
|
||||
_ = r.ProfileNames()
|
||||
_ = r.GetProfileByAccountID(int64(id * j))
|
||||
_ = r.GetProfile(DefaultProfileName)
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 20; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Test should pass without data races (run with -race flag)
|
||||
}
|
||||
@@ -39,9 +39,15 @@ import (
|
||||
// 设计说明:
|
||||
// - client: Ent 客户端,用于类型安全的 ORM 操作
|
||||
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
|
||||
// - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照
|
||||
type accountRepository struct {
|
||||
client *dbent.Client // Ent ORM 客户端
|
||||
sql sqlExecutor // 原生 SQL 执行接口
|
||||
// schedulerCache 用于在账号状态变更时主动同步快照到缓存,
|
||||
// 确保粘性会话能及时感知账号不可用状态。
|
||||
// Used to proactively sync account snapshot to cache when status changes,
|
||||
// ensuring sticky sessions can promptly detect unavailable accounts.
|
||||
schedulerCache service.SchedulerCache
|
||||
}
|
||||
|
||||
type tempUnschedSnapshot struct {
|
||||
@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct {
|
||||
|
||||
// NewAccountRepository 创建账户仓储实例。
|
||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
|
||||
return newAccountRepositoryWithSQL(client, sqlDB)
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||
return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache)
|
||||
}
|
||||
|
||||
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
|
||||
// 这种设计便于单元测试时注入 mock 对象。
|
||||
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
|
||||
return &accountRepository{client: client, sql: sqlq}
|
||||
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository {
|
||||
return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache}
|
||||
}
|
||||
|
||||
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
|
||||
@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -473,37 +482,6 @@ func (r *accountRepository) ListByPlatform(ctx context.Context, platform string)
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListByPlatformAndCredentialEmails(
|
||||
ctx context.Context,
|
||||
platform string,
|
||||
emails []string,
|
||||
) ([]service.Account, error) {
|
||||
if len(emails) == 0 {
|
||||
return []service.Account{}, nil
|
||||
}
|
||||
args := make([]any, 0, len(emails))
|
||||
for _, email := range emails {
|
||||
if email == "" {
|
||||
continue
|
||||
}
|
||||
args = append(args, email)
|
||||
}
|
||||
if len(args) == 0 {
|
||||
return []service.Account{}, nil
|
||||
}
|
||||
|
||||
accounts, err := r.client.Account.Query().
|
||||
Where(dbaccount.PlatformEQ(platform)).
|
||||
Where(func(s *entsql.Selector) {
|
||||
s.Where(sqljson.ValueIn(dbaccount.FieldCredentials, args, sqljson.Path("email")))
|
||||
}).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
now := time.Now()
|
||||
_, err := r.client.Account.Update().
|
||||
@@ -571,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。
|
||||
// 当账号被设置为错误、禁用、不可调度或临时不可调度时调用,
|
||||
// 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。
|
||||
//
|
||||
// syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache
|
||||
// when account status changes. Called when account is set to error, disabled,
|
||||
// unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session
|
||||
// logic can promptly detect the latest account state and avoid using unavailable accounts.
|
||||
func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
|
||||
if r == nil || r.schedulerCache == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
account, err := r.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
|
||||
return
|
||||
}
|
||||
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Account.Update().
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
@@ -787,7 +788,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
SetRateLimitedAt(now).
|
||||
SetRateLimitResetAt(resetAt).
|
||||
SetLastUsedAt(now).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -922,6 +922,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1018,7 +1019,16 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
|
||||
builder.SetSessionWindowEnd(*end)
|
||||
}
|
||||
_, err := builder.Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 触发调度器缓存更新(仅当窗口时间有变化时)
|
||||
if start != nil || end != nil {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
@@ -1032,6 +1042,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
}
|
||||
if !schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1186,6 +1199,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
}
|
||||
shouldSync := false
|
||||
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
|
||||
shouldSync = true
|
||||
}
|
||||
if updates.Schedulable != nil && !*updates.Schedulable {
|
||||
shouldSync = true
|
||||
}
|
||||
if shouldSync {
|
||||
for _, id := range ids {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
@@ -21,11 +21,56 @@ type AccountRepoSuite struct {
|
||||
repo *accountRepository
|
||||
}
|
||||
|
||||
type schedulerCacheRecorder struct {
|
||||
setAccounts []*service.Account
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
||||
s.setAccounts = append(s.setAccounts, account)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = newAccountRepositoryWithSQL(s.client, tx)
|
||||
s.repo = newAccountRepositoryWithSQL(s.client, tx, nil)
|
||||
}
|
||||
|
||||
func TestAccountRepoSuite(t *testing.T) {
|
||||
@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() {
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Status = service.StatusDisabled
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
tx := testEntTx(s.T())
|
||||
client := tx.Client()
|
||||
repo := newAccountRepositoryWithSQL(client, tx)
|
||||
repo := newAccountRepositoryWithSQL(client, tx, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.setup(client)
|
||||
@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
||||
|
||||
func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(got.Schedulable)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true})
|
||||
account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
disabled := service.StatusDisabled
|
||||
rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{
|
||||
Status: &disabled,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(2), rows)
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 2)
|
||||
ids := map[int64]struct{}{}
|
||||
for _, acc := range cacheRecorder.setAccounts {
|
||||
ids[acc.ID] = struct{}{}
|
||||
}
|
||||
s.Require().Contains(ids, account1.ID)
|
||||
s.Require().Contains(ids, account2.ID)
|
||||
}
|
||||
|
||||
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
||||
|
||||
95
backend/internal/repository/aes_encryptor.go
Normal file
95
backend/internal/repository/aes_encryptor.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// AESEncryptor implements SecretEncryptor using AES-256-GCM
|
||||
type AESEncryptor struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
// NewAESEncryptor creates a new AES encryptor
|
||||
func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
|
||||
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid totp encryption key: %w", err)
|
||||
}
|
||||
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
|
||||
}
|
||||
|
||||
return &AESEncryptor{key: key}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext using AES-256-GCM
|
||||
// Output format: base64(nonce + ciphertext + tag)
|
||||
func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
|
||||
block, err := aes.NewCipher(e.key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create gcm: %w", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the plaintext
|
||||
// Seal appends the ciphertext and tag to the nonce
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
|
||||
// Encode as base64
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext using AES-256-GCM
|
||||
func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
|
||||
// Decode from base64
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode base64: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(e.key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create gcm: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
83
backend/internal/repository/announcement_read_repo.go
Normal file
83
backend/internal/repository/announcement_read_repo.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcementread"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type announcementReadRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementReadRepository {
|
||||
return &announcementReadRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
return client.AnnouncementRead.Create().
|
||||
SetAnnouncementID(announcementID).
|
||||
SetUserID(userID).
|
||||
SetReadAt(readAt).
|
||||
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
|
||||
DoNothing().
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
|
||||
if len(announcementIDs) == 0 {
|
||||
return map[int64]time.Time{}, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.AnnouncementRead.Query().
|
||||
Where(
|
||||
announcementread.UserIDEQ(userID),
|
||||
announcementread.AnnouncementIDIn(announcementIDs...),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(map[int64]time.Time, len(rows))
|
||||
for i := range rows {
|
||||
out[rows[i].AnnouncementID] = rows[i].ReadAt
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return map[int64]time.Time{}, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.AnnouncementRead.Query().
|
||||
Where(
|
||||
announcementread.AnnouncementIDEQ(announcementID),
|
||||
announcementread.UserIDIn(userIDs...),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(map[int64]time.Time, len(rows))
|
||||
for i := range rows {
|
||||
out[rows[i].UserID] = rows[i].ReadAt
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error) {
|
||||
count, err := r.client.AnnouncementRead.Query().
|
||||
Where(announcementread.AnnouncementIDEQ(announcementID)).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
}
|
||||
194
backend/internal/repository/announcement_repo.go
Normal file
194
backend/internal/repository/announcement_repo.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcement"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type announcementRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAnnouncementRepository(client *dbent.Client) service.AnnouncementRepository {
|
||||
return &announcementRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Create(ctx context.Context, a *service.Announcement) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.Announcement.Create().
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
builder.SetStartsAt(*a.StartsAt)
|
||||
}
|
||||
if a.EndsAt != nil {
|
||||
builder.SetEndsAt(*a.EndsAt)
|
||||
}
|
||||
if a.CreatedBy != nil {
|
||||
builder.SetCreatedBy(*a.CreatedBy)
|
||||
}
|
||||
if a.UpdatedBy != nil {
|
||||
builder.SetUpdatedBy(*a.UpdatedBy)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applyAnnouncementEntityToService(a, created)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
|
||||
m, err := r.client.Announcement.Query().
|
||||
Where(announcement.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
|
||||
}
|
||||
return announcementEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Update(ctx context.Context, a *service.Announcement) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.Announcement.UpdateOneID(a.ID).
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
builder.SetStartsAt(*a.StartsAt)
|
||||
} else {
|
||||
builder.ClearStartsAt()
|
||||
}
|
||||
if a.EndsAt != nil {
|
||||
builder.SetEndsAt(*a.EndsAt)
|
||||
} else {
|
||||
builder.ClearEndsAt()
|
||||
}
|
||||
if a.CreatedBy != nil {
|
||||
builder.SetCreatedBy(*a.CreatedBy)
|
||||
} else {
|
||||
builder.ClearCreatedBy()
|
||||
}
|
||||
if a.UpdatedBy != nil {
|
||||
builder.SetUpdatedBy(*a.UpdatedBy)
|
||||
} else {
|
||||
builder.ClearUpdatedBy()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
|
||||
}
|
||||
|
||||
a.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Delete(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.Announcement.Delete().Where(announcement.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *announcementRepository) List(
|
||||
ctx context.Context,
|
||||
params pagination.PaginationParams,
|
||||
filters service.AnnouncementListFilters,
|
||||
) ([]service.Announcement, *pagination.PaginationResult, error) {
|
||||
q := r.client.Announcement.Query()
|
||||
|
||||
if filters.Status != "" {
|
||||
q = q.Where(announcement.StatusEQ(filters.Status))
|
||||
}
|
||||
if filters.Search != "" {
|
||||
q = q.Where(
|
||||
announcement.Or(
|
||||
announcement.TitleContainsFold(filters.Search),
|
||||
announcement.ContentContainsFold(filters.Search),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
items, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(announcement.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
out := announcementEntitiesToService(items)
|
||||
return out, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
|
||||
q := r.client.Announcement.Query().
|
||||
Where(
|
||||
announcement.StatusEQ(service.AnnouncementStatusActive),
|
||||
announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)),
|
||||
announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)),
|
||||
).
|
||||
Order(dbent.Desc(announcement.FieldID))
|
||||
|
||||
items, err := q.All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return announcementEntitiesToService(items), nil
|
||||
}
|
||||
|
||||
func applyAnnouncementEntityToService(dst *service.Announcement, src *dbent.Announcement) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
dst.ID = src.ID
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Announcement{
|
||||
ID: m.ID,
|
||||
Title: m.Title,
|
||||
Content: m.Content,
|
||||
Status: m.Status,
|
||||
Targeting: m.Targeting,
|
||||
StartsAt: m.StartsAt,
|
||||
EndsAt: m.EndsAt,
|
||||
CreatedBy: m.CreatedBy,
|
||||
UpdatedBy: m.UpdatedBy,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func announcementEntitiesToService(models []*dbent.Announcement) []service.Announcement {
|
||||
out := make([]service.Announcement, 0, len(models))
|
||||
for i := range models {
|
||||
if s := announcementEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -12,9 +13,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
apiKeyAuthCachePrefix = "apikey:auth:"
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
apiKeyAuthCachePrefix = "apikey:auth:"
|
||||
authCacheInvalidateChannel = "auth:cache:invalidate"
|
||||
)
|
||||
|
||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||
@@ -91,3 +93,45 @@ func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *servi
|
||||
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
|
||||
}
|
||||
|
||||
// PublishAuthCacheInvalidation publishes a cache invalidation message to all instances
|
||||
func (c *apiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return c.rdb.Publish(ctx, authCacheInvalidateChannel, cacheKey).Err()
|
||||
}
|
||||
|
||||
// SubscribeAuthCacheInvalidation subscribes to cache invalidation messages
|
||||
func (c *apiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
pubsub := c.rdb.Subscribe(ctx, authCacheInvalidateChannel)
|
||||
|
||||
// Verify subscription is working
|
||||
_, err := pubsub.Receive(ctx)
|
||||
if err != nil {
|
||||
_ = pubsub.Close()
|
||||
return fmt.Errorf("subscribe to auth cache invalidation: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := pubsub.Close(); err != nil {
|
||||
log.Printf("Warning: failed to close auth cache invalidation pubsub: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ch := pubsub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if msg != nil {
|
||||
handler(msg.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -389,17 +389,20 @@ func userEntityToService(u *dbent.User) *service.User {
|
||||
return nil
|
||||
}
|
||||
return &service.User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||
TotpEnabled: u.TotpEnabled,
|
||||
TotpEnabledAt: u.TotpEnabledAt,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
var orgs []struct {
|
||||
UUID string `json:"uuid"`
|
||||
UUID string `json:"uuid"`
|
||||
Name string `json:"name"`
|
||||
RavenType *string `json:"raven_type"` // nil for personal, "team" for team organization
|
||||
}
|
||||
|
||||
targetURL := s.baseURL + "/api/organizations"
|
||||
@@ -65,7 +67,23 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
return "", fmt.Errorf("no organizations found")
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
||||
// 如果只有一个组织,直接使用
|
||||
if len(orgs) == 1 {
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
|
||||
for _, org := range orgs {
|
||||
if org.RavenType != nil && *org.RavenType == "team" {
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
|
||||
org.UUID, org.Name, *org.RavenType)
|
||||
return org.UUID, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 team 类型的组织,使用第一个
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
@@ -182,7 +200,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json, text/plain, */*").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", "axios/1.8.4").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
@@ -205,8 +225,6 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
|
||||
// Anthropic OAuth API 期望 JSON 格式的请求体
|
||||
reqBody := map[string]any{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
@@ -217,7 +235,9 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json, text/plain, */*").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", "axios/1.8.4").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
@@ -171,7 +171,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
||||
s.client.baseURL = "http://in-process"
|
||||
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
||||
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
|
||||
@@ -14,37 +14,82 @@ import (
|
||||
|
||||
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
// 默认 User-Agent,与用户抓包的请求一致
|
||||
const defaultUsageUserAgent = "claude-code/2.1.7"
|
||||
|
||||
type claudeUsageService struct {
|
||||
usageURL string
|
||||
allowPrivateHosts bool
|
||||
httpUpstream service.HTTPUpstream
|
||||
}
|
||||
|
||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
|
||||
// NewClaudeUsageFetcher 创建 Claude 用量获取服务
|
||||
// httpUpstream: 可选,如果提供则支持 TLS 指纹伪装
|
||||
func NewClaudeUsageFetcher(httpUpstream service.HTTPUpstream) service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{
|
||||
usageURL: defaultClaudeUsageURL,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
// FetchUsage 简单版本,不支持 TLS 指纹(向后兼容)
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: s.allowPrivateHosts,
|
||||
return s.FetchUsageWithOptions(ctx, &service.ClaudeUsageFetchOptions{
|
||||
AccessToken: accessToken,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
// FetchUsageWithOptions 完整版本,支持 TLS 指纹和自定义 User-Agent
|
||||
func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *service.ClaudeUsageFetchOptions) (*service.ClaudeUsageResponse, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("options is nil")
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
// 设置请求头(与抓包一致,但不设置 Accept-Encoding,让 Go 自动处理压缩)
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+opts.AccessToken)
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
// 设置 User-Agent(优先使用缓存的 Fingerprint,否则使用默认值)
|
||||
userAgent := defaultUsageUserAgent
|
||||
if opts.Fingerprint != nil && opts.Fingerprint.UserAgent != "" {
|
||||
userAgent = opts.Fingerprint.UserAgent
|
||||
}
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
var resp *http.Response
|
||||
|
||||
// 如果启用 TLS 指纹且有 HTTPUpstream,使用 DoWithTLS
|
||||
if opts.EnableTLSFingerprint && s.httpUpstream != nil {
|
||||
// accountConcurrency 传 0 使用默认连接池配置,usage 请求不需要特殊的并发设置
|
||||
resp, err = s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, 0, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request with TLS fingerprint failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 不启用 TLS 指纹,使用普通 HTTP 客户端
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: opts.ProxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: s.allowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
|
||||
@@ -77,6 +77,75 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
if r == nil || r.sql == nil {
|
||||
return nil
|
||||
}
|
||||
loc := timezone.Location()
|
||||
startLocal := start.In(loc)
|
||||
endLocal := end.In(loc)
|
||||
if !endLocal.After(startLocal) {
|
||||
return nil
|
||||
}
|
||||
|
||||
hourStart := startLocal.Truncate(time.Hour)
|
||||
hourEnd := endLocal.Truncate(time.Hour)
|
||||
if endLocal.After(hourEnd) {
|
||||
hourEnd = hourEnd.Add(time.Hour)
|
||||
}
|
||||
|
||||
dayStart := truncateToDay(startLocal)
|
||||
dayEnd := truncateToDay(endLocal)
|
||||
if endLocal.After(dayEnd) {
|
||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
// 尽量使用事务保证范围内的一致性(允许在非 *sql.DB 的情况下退化为非事务执行)。
|
||||
if db, ok := r.sql.(*sql.DB); ok {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
|
||||
if err := txRepo.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
return r.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) recomputeRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
|
||||
// 先清空范围内桶,再重建(避免仅增量插入导致活跃用户等指标无法回退)。
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
var ts time.Time
|
||||
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
|
||||
|
||||
@@ -9,13 +9,27 @@ import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const verifyCodeKeyPrefix = "verify_code:"
|
||||
const (
|
||||
verifyCodeKeyPrefix = "verify_code:"
|
||||
passwordResetKeyPrefix = "password_reset:"
|
||||
passwordResetSentAtKeyPrefix = "password_reset_sent:"
|
||||
)
|
||||
|
||||
// verifyCodeKey generates the Redis key for email verification code.
|
||||
func verifyCodeKey(email string) string {
|
||||
return verifyCodeKeyPrefix + email
|
||||
}
|
||||
|
||||
// passwordResetKey generates the Redis key for password reset token.
|
||||
func passwordResetKey(email string) string {
|
||||
return passwordResetKeyPrefix + email
|
||||
}
|
||||
|
||||
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
|
||||
func passwordResetSentAtKey(email string) string {
|
||||
return passwordResetSentAtKeyPrefix + email
|
||||
}
|
||||
|
||||
type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
|
||||
key := verifyCodeKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// Password reset token methods
|
||||
|
||||
func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
|
||||
key := passwordResetKey(email)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data service.PasswordResetTokenData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
|
||||
key := passwordResetKey(email)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
|
||||
key := passwordResetKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// Password reset email cooldown methods
|
||||
|
||||
func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
|
||||
key := passwordResetSentAtKey(email)
|
||||
exists, err := c.rdb.Exists(ctx, key).Result()
|
||||
return err == nil && exists > 0
|
||||
}
|
||||
|
||||
func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
|
||||
key := passwordResetSentAtKey(email)
|
||||
return c.rdb.Set(ctx, key, "1", ttl).Err()
|
||||
}
|
||||
|
||||
@@ -65,5 +65,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
|
||||
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
|
||||
client := ent.NewClient(ent.Driver(drv))
|
||||
|
||||
// SIMPLE 模式:启动时补齐各平台默认分组。
|
||||
// - anthropic/openai/gemini: 确保存在 <platform>-default
|
||||
// - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景)
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
seedCtx, seedCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer seedCancel()
|
||||
if err := ensureSimpleModeDefaultGroups(seedCtx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return client, drv.DB(), nil
|
||||
}
|
||||
|
||||
@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||
}
|
||||
|
||||
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
|
||||
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
|
||||
// 以便下次请求能够重新选择可用账号。
|
||||
//
|
||||
// DeleteSessionAccountID removes the sticky session binding for the given session.
|
||||
// Called when the bound account becomes unavailable (e.g., error status, disabled,
|
||||
// or unschedulable), allowing subsequent requests to select a new available account.
|
||||
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
|
||||
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
|
||||
sessionID := "openai:s4"
|
||||
accountID := int64(102)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
|
||||
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
sessionID := "corrupted"
|
||||
groupID := int64(1)
|
||||
|
||||
@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
|
||||
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
|
||||
}
|
||||
|
||||
func TestGatewayRoutingSuite(t *testing.T) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
@@ -150,6 +152,172 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
|
||||
// 根据 enableTLSFingerprint 参数决定是否使用 TLS 指纹
|
||||
//
|
||||
// 参数:
|
||||
// - req: HTTP 请求对象
|
||||
// - proxyURL: 代理地址,空字符串表示直连
|
||||
// - accountID: 账户 ID,用于账户级隔离和 TLS 指纹模板选择
|
||||
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
|
||||
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
|
||||
//
|
||||
// TLS 指纹说明:
|
||||
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
// - 指纹模板根据 accountID % len(profiles) 自动选择
|
||||
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
|
||||
func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
// 如果未启用 TLS 指纹,直接使用标准请求路径
|
||||
if !enableTLSFingerprint {
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
// TLS 指纹已启用,记录调试日志
|
||||
targetHost := ""
|
||||
if req != nil && req.URL != nil {
|
||||
targetHost = req.URL.Host
|
||||
}
|
||||
proxyInfo := "direct"
|
||||
if proxyURL != "" {
|
||||
proxyInfo = proxyURL
|
||||
}
|
||||
slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo)
|
||||
|
||||
if err := s.validateRequestHost(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取 TLS 指纹 Profile
|
||||
registry := tlsfingerprint.GlobalRegistry()
|
||||
profile := registry.GetProfileByAccountID(accountID)
|
||||
if profile == nil {
|
||||
// 如果获取不到 profile,回退到普通请求
|
||||
slog.Debug("tls_fingerprint_no_profile", "account_id", accountID, "fallback", "standard_request")
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
slog.Debug("tls_fingerprint_using_profile", "account_id", accountID, "profile", profile.Name, "grease", profile.EnableGREASE)
|
||||
|
||||
// 获取或创建带 TLS 指纹的客户端
|
||||
entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
resp, err := entry.client.Do(req)
|
||||
if err != nil {
|
||||
// 请求失败,立即减少计数
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
slog.Debug("tls_fingerprint_request_failed", "account_id", accountID, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.Debug("tls_fingerprint_request_success", "account_id", accountID, "status", resp.StatusCode)
|
||||
|
||||
// 包装响应体,在关闭时自动减少计数并更新时间戳
|
||||
resp.Body = wrapTrackedBody(resp.Body, func() {
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
})
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// acquireClientWithTLS 获取或创建带 TLS 指纹的客户端
|
||||
func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) {
|
||||
return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true)
|
||||
}
|
||||
|
||||
// getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目
|
||||
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
|
||||
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
|
||||
isolation := s.getIsolationMode()
|
||||
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
|
||||
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
|
||||
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
|
||||
poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
|
||||
|
||||
now := time.Now()
|
||||
nowUnix := now.UnixNano()
|
||||
|
||||
// 读锁快速路径
|
||||
s.mu.RLock()
|
||||
if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.AddInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
|
||||
return entry, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
// 写锁慢路径
|
||||
s.mu.Lock()
|
||||
if entry, ok := s.clients[cacheKey]; ok {
|
||||
if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.AddInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
|
||||
return entry, nil
|
||||
}
|
||||
slog.Debug("tls_fingerprint_evicting_stale_client",
|
||||
"account_id", accountID,
|
||||
"cache_key", cacheKey,
|
||||
"proxy_changed", entry.proxyKey != proxyKey,
|
||||
"pool_changed", entry.poolKey != poolKey)
|
||||
s.removeClientLocked(cacheKey, entry)
|
||||
}
|
||||
|
||||
// 超出缓存上限时尝试淘汰
|
||||
if enforceLimit && s.maxUpstreamClients() > 0 {
|
||||
s.evictIdleLocked(now)
|
||||
if len(s.clients) >= s.maxUpstreamClients() {
|
||||
if !s.evictOldestIdleLocked() {
|
||||
s.mu.Unlock()
|
||||
return nil, errUpstreamClientLimitReached
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建带 TLS 指纹的 Transport
|
||||
slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", proxyKey)
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return nil, fmt.Errorf("build TLS fingerprint transport: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
if s.shouldValidateResolvedIP() {
|
||||
client.CheckRedirect = s.redirectChecker
|
||||
}
|
||||
|
||||
entry := &upstreamClientEntry{
|
||||
client: client,
|
||||
proxyKey: proxyKey,
|
||||
poolKey: poolKey,
|
||||
}
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.StoreInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.clients[cacheKey] = entry
|
||||
|
||||
s.evictIdleLocked(now)
|
||||
s.evictOverLimitLocked()
|
||||
s.mu.Unlock()
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
|
||||
if s.cfg == nil {
|
||||
return false
|
||||
@@ -618,6 +786,64 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// buildUpstreamTransportWithTLSFingerprint 构建带 TLS 指纹伪装的 Transport
|
||||
// 使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
//
|
||||
// 参数:
|
||||
// - settings: 连接池配置
|
||||
// - proxyURL: 代理 URL(nil 表示直连)
|
||||
// - profile: TLS 指纹配置
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Transport: 配置好的 Transport 实例
|
||||
// - error: 配置错误
|
||||
//
|
||||
// 代理类型处理:
|
||||
// - nil/空: 直连,使用 TLSFingerprintDialer
|
||||
// - http/https: HTTP 代理,使用 HTTPProxyDialer(CONNECT 隧道 + utls 握手)
|
||||
// - socks5: SOCKS5 代理,使用 SOCKS5ProxyDialer(SOCKS5 隧道 + utls 握手)
|
||||
func buildUpstreamTransportWithTLSFingerprint(settings poolSettings, proxyURL *url.URL, profile *tlsfingerprint.Profile) (*http.Transport, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: settings.maxIdleConns,
|
||||
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: settings.maxConnsPerHost,
|
||||
IdleConnTimeout: settings.idleConnTimeout,
|
||||
ResponseHeaderTimeout: settings.responseHeaderTimeout,
|
||||
// 禁用默认的 TLS,我们使用自定义的 DialTLSContext
|
||||
ForceAttemptHTTP2: false,
|
||||
}
|
||||
|
||||
// 根据代理类型选择合适的 TLS 指纹 Dialer
|
||||
if proxyURL == nil {
|
||||
// 直连:使用 TLSFingerprintDialer
|
||||
slog.Debug("tls_fingerprint_transport_direct")
|
||||
dialer := tlsfingerprint.NewDialer(profile, nil)
|
||||
transport.DialTLSContext = dialer.DialTLSContext
|
||||
} else {
|
||||
scheme := strings.ToLower(proxyURL.Scheme)
|
||||
switch scheme {
|
||||
case "socks5", "socks5h":
|
||||
// SOCKS5 代理:使用 SOCKS5ProxyDialer
|
||||
slog.Debug("tls_fingerprint_transport_socks5", "proxy", proxyURL.Host)
|
||||
socks5Dialer := tlsfingerprint.NewSOCKS5ProxyDialer(profile, proxyURL)
|
||||
transport.DialTLSContext = socks5Dialer.DialTLSContext
|
||||
case "http", "https":
|
||||
// HTTP/HTTPS 代理:使用 HTTPProxyDialer(CONNECT 隧道)
|
||||
slog.Debug("tls_fingerprint_transport_http_connect", "proxy", proxyURL.Host)
|
||||
httpDialer := tlsfingerprint.NewHTTPProxyDialer(profile, proxyURL)
|
||||
transport.DialTLSContext = httpDialer.DialTLSContext
|
||||
default:
|
||||
// 未知代理类型,回退到普通代理配置(无 TLS 指纹)
|
||||
slog.Debug("tls_fingerprint_transport_unknown_scheme_fallback", "scheme", scheme)
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// trackedBody 带跟踪功能的响应体包装器
|
||||
// 在 Close 时执行回调,用于更新请求计数
|
||||
type trackedBody struct {
|
||||
|
||||
@@ -11,8 +11,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
fingerprintKeyPrefix = "fingerprint:"
|
||||
fingerprintTTL = 24 * time.Hour
|
||||
fingerprintKeyPrefix = "fingerprint:"
|
||||
fingerprintTTL = 24 * time.Hour
|
||||
maskedSessionKeyPrefix = "masked_session:"
|
||||
maskedSessionTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
// fingerprintKey generates the Redis key for account fingerprint cache.
|
||||
@@ -20,6 +22,11 @@ func fingerprintKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
// maskedSessionKey generates the Redis key for masked session ID cache.
|
||||
func maskedSessionKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", maskedSessionKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
type identityCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -49,3 +56,20 @@ func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
|
||||
}
|
||||
|
||||
func (c *identityCache) GetMaskedSessionID(ctx context.Context, accountID int64) (string, error) {
|
||||
key := maskedSessionKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *identityCache) SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error {
|
||||
key := maskedSessionKey(accountID)
|
||||
return c.rdb.Set(ctx, key, sessionID, maskedSessionTTL).Err()
|
||||
}
|
||||
|
||||
@@ -2,10 +2,11 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/imroc/req/v3"
|
||||
@@ -38,16 +39,17 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("User-Agent", "codex-cli/0.91.0").
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
@@ -66,16 +68,17 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("User-Agent", "codex-cli/0.91.0").
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
@@ -84,6 +87,6 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
func createOpenAIReqClient(proxyURL string) *req.Client {
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 60 * time.Second,
|
||||
Timeout: 120 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
|
||||
require.ErrorContains(s.T(), err, "status 401")
|
||||
}
|
||||
|
||||
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
|
||||
client := NewOpenAIOAuthClient()
|
||||
svc, ok := client.(*openaiOAuthService)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, openai.TokenURL, svc.tokenURL)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(OpenAIOAuthServiceSuite))
|
||||
}
|
||||
|
||||
@@ -992,7 +992,8 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
}
|
||||
|
||||
// View filter: errors vs excluded vs all.
|
||||
// Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
|
||||
// Excluded = business-limited errors (quota/concurrency/billing).
|
||||
// Upstream 429/529 are included in errors view to match SLA calculation.
|
||||
view := ""
|
||||
if filter != nil {
|
||||
view = strings.ToLower(strings.TrimSpace(filter.View))
|
||||
@@ -1000,15 +1001,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
switch view {
|
||||
case "", "errors":
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
||||
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
|
||||
case "excluded":
|
||||
clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))")
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
|
||||
case "all":
|
||||
// no-op
|
||||
default:
|
||||
// treat unknown as default 'errors'
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
||||
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
|
||||
}
|
||||
if len(filter.StatusCodes) > 0 {
|
||||
args = append(args, pq.Array(filter.StatusCodes))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -26,7 +27,7 @@ func InitRedis(cfg *config.Config) *redis.Client {
|
||||
// buildRedisOptions 构建 Redis 连接选项
|
||||
// 从配置文件读取连接池和超时参数,支持生产环境调优
|
||||
func buildRedisOptions(cfg *config.Config) *redis.Options {
|
||||
return &redis.Options{
|
||||
opts := &redis.Options{
|
||||
Addr: cfg.Redis.Address(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
@@ -36,4 +37,13 @@ func buildRedisOptions(cfg *config.Config) *redis.Options {
|
||||
PoolSize: cfg.Redis.PoolSize, // 连接池大小
|
||||
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
|
||||
}
|
||||
|
||||
if cfg.Redis.EnableTLS {
|
||||
opts.TLSConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: cfg.Redis.Host,
|
||||
}
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
@@ -32,4 +32,16 @@ func TestBuildRedisOptions(t *testing.T) {
|
||||
require.Equal(t, 4*time.Second, opts.WriteTimeout)
|
||||
require.Equal(t, 100, opts.PoolSize)
|
||||
require.Equal(t, 10, opts.MinIdleConns)
|
||||
require.Nil(t, opts.TLSConfig)
|
||||
|
||||
// Test case with TLS enabled
|
||||
cfgTLS := &config.Config{
|
||||
Redis: config.RedisConfig{
|
||||
Host: "localhost",
|
||||
EnableTLS: true,
|
||||
},
|
||||
}
|
||||
optsTLS := buildRedisOptions(cfgTLS)
|
||||
require.NotNil(t, optsTLS.TLSConfig)
|
||||
require.Equal(t, "localhost", optsTLS.TLSConfig.ServerName)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ type reqClientOptions struct {
|
||||
ProxyURL string // 代理 URL(支持 http/https/socks5)
|
||||
Timeout time.Duration // 请求超时时间
|
||||
Impersonate bool // 是否模拟 Chrome 浏览器指纹
|
||||
ForceHTTP2 bool // 是否强制使用 HTTP/2
|
||||
}
|
||||
|
||||
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
|
||||
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
|
||||
}
|
||||
|
||||
client := req.C().SetTimeout(opts.Timeout)
|
||||
if opts.ForceHTTP2 {
|
||||
client = client.EnableForceHTTP2()
|
||||
}
|
||||
if opts.Impersonate {
|
||||
client = client.ImpersonateChrome()
|
||||
}
|
||||
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
|
||||
}
|
||||
|
||||
func buildReqClientKey(opts reqClientOptions) string {
|
||||
return fmt.Sprintf("%s|%s|%t",
|
||||
return fmt.Sprintf("%s|%s|%t|%t",
|
||||
strings.TrimSpace(opts.ProxyURL),
|
||||
opts.Timeout.String(),
|
||||
opts.Impersonate,
|
||||
opts.ForceHTTP2,
|
||||
)
|
||||
}
|
||||
|
||||
90
backend/internal/repository/req_client_pool_test.go
Normal file
90
backend/internal/repository/req_client_pool_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func forceHTTPVersion(t *testing.T, client *req.Client) string {
|
||||
t.Helper()
|
||||
transport := client.GetTransport()
|
||||
field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
|
||||
require.True(t, field.IsValid(), "forceHttpVersion field not found")
|
||||
require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
|
||||
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
|
||||
}
|
||||
|
||||
func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
base := reqClientOptions{
|
||||
ProxyURL: "http://proxy.local:8080",
|
||||
Timeout: time.Second,
|
||||
}
|
||||
clientDefault := getSharedReqClient(base)
|
||||
|
||||
force := base
|
||||
force.ForceHTTP2 = true
|
||||
clientForce := getSharedReqClient(force)
|
||||
|
||||
require.NotSame(t, clientDefault, clientForce)
|
||||
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
|
||||
}
|
||||
|
||||
func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
opts := reqClientOptions{
|
||||
ProxyURL: "http://proxy.local:8080",
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
first := getSharedReqClient(opts)
|
||||
second := getSharedReqClient(opts)
|
||||
require.Same(t, first, second)
|
||||
}
|
||||
|
||||
func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
opts := reqClientOptions{
|
||||
ProxyURL: " http://proxy.local:8080 ",
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
key := buildReqClientKey(opts)
|
||||
sharedReqClients.Store(key, "invalid")
|
||||
|
||||
client := getSharedReqClient(opts)
|
||||
|
||||
require.NotNil(t, client)
|
||||
loaded, ok := sharedReqClients.Load(key)
|
||||
require.True(t, ok)
|
||||
require.IsType(t, "invalid", loaded)
|
||||
}
|
||||
|
||||
func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
opts := reqClientOptions{
|
||||
ProxyURL: " http://proxy.local:8080 ",
|
||||
Timeout: 4 * time.Second,
|
||||
Impersonate: true,
|
||||
}
|
||||
client := getSharedReqClient(opts)
|
||||
|
||||
require.NotNil(t, client)
|
||||
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
|
||||
}
|
||||
|
||||
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
client := createOpenAIReqClient("http://proxy.local:8080")
|
||||
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
|
||||
}
|
||||
|
||||
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
|
||||
sharedReqClients = sync.Map{}
|
||||
client := createGeminiReqClient("http://proxy.local:8080")
|
||||
require.Equal(t, "", forceHTTPVersion(t, client))
|
||||
}
|
||||
@@ -58,7 +58,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
|
||||
return nil, false, err
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return []*service.Account{}, true, nil
|
||||
// 空快照视为缓存未命中,触发数据库回退查询
|
||||
// 这解决了新分组创建后立即绑定账号时的竞态条件问题
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(ids))
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
|
||||
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
|
||||
|
||||
accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
|
||||
accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
|
||||
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
|
||||
cache := NewSchedulerCache(rdb)
|
||||
|
||||
|
||||
@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
|
||||
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return make(map[int64]int), nil
|
||||
}
|
||||
@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco
|
||||
|
||||
// 使用 pipeline 批量执行
|
||||
pipe := c.rdb.Pipeline()
|
||||
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
|
||||
|
||||
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
key := sessionLimitKey(accountID)
|
||||
// 使用各账号自己的 idleTimeout,如果没有则用默认值
|
||||
idleTimeout := c.defaultIdleTimeout
|
||||
if idleTimeouts != nil {
|
||||
if t, ok := idleTimeouts[accountID]; ok && t > 0 {
|
||||
idleTimeout = t
|
||||
}
|
||||
}
|
||||
idleTimeoutSeconds := int(idleTimeout.Seconds())
|
||||
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
|
||||
}
|
||||
|
||||
|
||||
82
backend/internal/repository/simple_mode_default_groups.go
Normal file
82
backend/internal/repository/simple_mode_default_groups.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) error {
|
||||
if client == nil {
|
||||
return fmt.Errorf("nil ent client")
|
||||
}
|
||||
|
||||
requiredByPlatform := map[string]int{
|
||||
service.PlatformAnthropic: 1,
|
||||
service.PlatformOpenAI: 1,
|
||||
service.PlatformGemini: 1,
|
||||
service.PlatformAntigravity: 2,
|
||||
}
|
||||
|
||||
for platform, minCount := range requiredByPlatform {
|
||||
count, err := client.Group.Query().
|
||||
Where(group.PlatformEQ(platform), group.DeletedAtIsNil()).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("count groups for platform %s: %w", platform, err)
|
||||
}
|
||||
|
||||
if platform == service.PlatformAntigravity {
|
||||
if count < minCount {
|
||||
for i := count; i < minCount; i++ {
|
||||
name := fmt.Sprintf("%s-default-%d", platform, i+1)
|
||||
if err := createGroupIfNotExists(ctx, client, name, platform); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Non-antigravity platforms: ensure <platform>-default exists.
|
||||
name := platform + "-default"
|
||||
if err := createGroupIfNotExists(ctx, client, name, platform); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createGroupIfNotExists(ctx context.Context, client *dbent.Client, name, platform string) error {
|
||||
exists, err := client.Group.Query().
|
||||
Where(group.NameEQ(name), group.DeletedAtIsNil()).
|
||||
Exist(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check group exists %s: %w", name, err)
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = client.Group.Create().
|
||||
SetName(name).
|
||||
SetDescription("Auto-created default group").
|
||||
SetPlatform(platform).
|
||||
SetStatus(service.StatusActive).
|
||||
SetSubscriptionType(service.SubscriptionTypeStandard).
|
||||
SetRateMultiplier(1.0).
|
||||
SetIsExclusive(false).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsConstraintError(err) {
|
||||
// Concurrent server startups may race on creation; treat as success.
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("create default group %s: %w", name, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
client := tx.Client()
|
||||
|
||||
seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
|
||||
|
||||
assertGroupExists := func(name string) {
|
||||
exists, err := client.Group.Query().Where(group.NameEQ(name), group.DeletedAtIsNil()).Exist(seedCtx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists, "expected group %s to exist", name)
|
||||
}
|
||||
|
||||
assertGroupExists(service.PlatformAnthropic + "-default")
|
||||
assertGroupExists(service.PlatformOpenAI + "-default")
|
||||
assertGroupExists(service.PlatformGemini + "-default")
|
||||
assertGroupExists(service.PlatformAntigravity + "-default-1")
|
||||
assertGroupExists(service.PlatformAntigravity + "-default-2")
|
||||
}
|
||||
|
||||
func TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
client := tx.Client()
|
||||
|
||||
seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Create and then soft-delete an anthropic default group.
|
||||
g, err := client.Group.Create().
|
||||
SetName(service.PlatformAnthropic + "-default").
|
||||
SetPlatform(service.PlatformAnthropic).
|
||||
SetStatus(service.StatusActive).
|
||||
SetSubscriptionType(service.SubscriptionTypeStandard).
|
||||
SetRateMultiplier(1.0).
|
||||
SetIsExclusive(false).
|
||||
Save(seedCtx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.Group.Delete().Where(group.IDEQ(g.ID)).Exec(seedCtx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
|
||||
|
||||
// New active one should exist.
|
||||
count, err := client.Group.Query().Where(group.NameEQ(service.PlatformAnthropic+"-default"), group.DeletedAtIsNil()).Count(seedCtx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
client := tx.Client()
|
||||
|
||||
seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mustCreateGroup(t, client, &service.Group{Name: "ag-custom-1-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity})
|
||||
mustCreateGroup(t, client, &service.Group{Name: "ag-custom-2-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity})
|
||||
|
||||
require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
|
||||
|
||||
count, err := client.Group.Query().Where(group.PlatformEQ(service.PlatformAntigravity), group.DeletedAtIsNil()).Count(seedCtx)
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, count, 2)
|
||||
}
|
||||
149
backend/internal/repository/totp_cache.go
Normal file
149
backend/internal/repository/totp_cache.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
totpSetupKeyPrefix = "totp:setup:"
|
||||
totpLoginKeyPrefix = "totp:login:"
|
||||
totpAttemptsKeyPrefix = "totp:attempts:"
|
||||
totpAttemptsTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
// TotpCache implements service.TotpCache using Redis
|
||||
type TotpCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewTotpCache creates a new TOTP cache
|
||||
func NewTotpCache(rdb *redis.Client) service.TotpCache {
|
||||
return &TotpCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// GetSetupSession retrieves a TOTP setup session
|
||||
func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
|
||||
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
|
||||
data, err := c.rdb.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get setup session: %w", err)
|
||||
}
|
||||
|
||||
var session service.TotpSetupSession
|
||||
if err := json.Unmarshal(data, &session); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal setup session: %w", err)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// SetSetupSession stores a TOTP setup session
|
||||
func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
|
||||
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
|
||||
data, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal setup session: %w", err)
|
||||
}
|
||||
|
||||
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
|
||||
return fmt.Errorf("set setup session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSetupSession deletes a TOTP setup session
|
||||
func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// GetLoginSession retrieves a TOTP login session
|
||||
func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
|
||||
key := totpLoginKeyPrefix + tempToken
|
||||
data, err := c.rdb.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get login session: %w", err)
|
||||
}
|
||||
|
||||
var session service.TotpLoginSession
|
||||
if err := json.Unmarshal(data, &session); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal login session: %w", err)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// SetLoginSession stores a TOTP login session
|
||||
func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
|
||||
key := totpLoginKeyPrefix + tempToken
|
||||
data, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal login session: %w", err)
|
||||
}
|
||||
|
||||
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
|
||||
return fmt.Errorf("set login session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteLoginSession deletes a TOTP login session
|
||||
func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
|
||||
key := totpLoginKeyPrefix + tempToken
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// IncrementVerifyAttempts increments the verify attempt counter
|
||||
func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
|
||||
|
||||
// Use pipeline for atomic increment and set TTL
|
||||
pipe := c.rdb.Pipeline()
|
||||
incrCmd := pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, totpAttemptsTTL)
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return 0, fmt.Errorf("increment verify attempts: %w", err)
|
||||
}
|
||||
|
||||
count, err := incrCmd.Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get increment result: %w", err)
|
||||
}
|
||||
|
||||
return int(count), nil
|
||||
}
|
||||
|
||||
// GetVerifyAttempts gets the current verify attempt count
|
||||
func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
|
||||
count, err := c.rdb.Get(ctx, key).Int()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, fmt.Errorf("get verify attempts: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ClearVerifyAttempts clears the verify attempt counter
|
||||
func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
551
backend/internal/repository/usage_cleanup_repo.go
Normal file
551
backend/internal/repository/usage_cleanup_repo.go
Normal file
@@ -0,0 +1,551 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
dbusagecleanuptask "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type usageCleanupRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewUsageCleanupRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageCleanupRepository {
|
||||
return newUsageCleanupRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func newUsageCleanupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageCleanupRepository {
|
||||
return &usageCleanupRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
if r.client != nil {
|
||||
return r.createTaskWithEnt(ctx, task)
|
||||
}
|
||||
return r.createTaskWithSQL(ctx, task)
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
|
||||
if r.client != nil {
|
||||
return r.listTasksWithEnt(ctx, params)
|
||||
}
|
||||
var total int64
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM usage_cleanup_tasks", nil, &total); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if total == 0 {
|
||||
return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, status, filters, created_by, deleted_rows, error_message,
|
||||
canceled_by, canceled_at,
|
||||
started_at, finished_at, created_at, updated_at
|
||||
FROM usage_cleanup_tasks
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, params.Limit(), params.Offset())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
tasks := make([]service.UsageCleanupTask, 0)
|
||||
for rows.Next() {
|
||||
var task service.UsageCleanupTask
|
||||
var filtersJSON []byte
|
||||
var errMsg sql.NullString
|
||||
var canceledBy sql.NullInt64
|
||||
var canceledAt sql.NullTime
|
||||
var startedAt sql.NullTime
|
||||
var finishedAt sql.NullTime
|
||||
if err := rows.Scan(
|
||||
&task.ID,
|
||||
&task.Status,
|
||||
&filtersJSON,
|
||||
&task.CreatedBy,
|
||||
&task.DeletedRows,
|
||||
&errMsg,
|
||||
&canceledBy,
|
||||
&canceledAt,
|
||||
&startedAt,
|
||||
&finishedAt,
|
||||
&task.CreatedAt,
|
||||
&task.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
|
||||
return nil, nil, fmt.Errorf("parse cleanup filters: %w", err)
|
||||
}
|
||||
if errMsg.Valid {
|
||||
task.ErrorMsg = &errMsg.String
|
||||
}
|
||||
if canceledBy.Valid {
|
||||
v := canceledBy.Int64
|
||||
task.CanceledBy = &v
|
||||
}
|
||||
if canceledAt.Valid {
|
||||
task.CanceledAt = &canceledAt.Time
|
||||
}
|
||||
if startedAt.Valid {
|
||||
task.StartedAt = &startedAt.Time
|
||||
}
|
||||
if finishedAt.Valid {
|
||||
task.FinishedAt = &finishedAt.Time
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return tasks, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
|
||||
if staleRunningAfterSeconds <= 0 {
|
||||
staleRunningAfterSeconds = 1800
|
||||
}
|
||||
query := `
|
||||
WITH next AS (
|
||||
SELECT id
|
||||
FROM usage_cleanup_tasks
|
||||
WHERE status = $1
|
||||
OR (
|
||||
status = $2
|
||||
AND started_at IS NOT NULL
|
||||
AND started_at < NOW() - ($3 * interval '1 second')
|
||||
)
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
FOR UPDATE SKIP LOCKED
|
||||
)
|
||||
UPDATE usage_cleanup_tasks AS tasks
|
||||
SET status = $4,
|
||||
started_at = NOW(),
|
||||
finished_at = NULL,
|
||||
error_message = NULL,
|
||||
updated_at = NOW()
|
||||
FROM next
|
||||
WHERE tasks.id = next.id
|
||||
RETURNING tasks.id, tasks.status, tasks.filters, tasks.created_by, tasks.deleted_rows, tasks.error_message,
|
||||
tasks.started_at, tasks.finished_at, tasks.created_at, tasks.updated_at
|
||||
`
|
||||
var task service.UsageCleanupTask
|
||||
var filtersJSON []byte
|
||||
var errMsg sql.NullString
|
||||
var startedAt sql.NullTime
|
||||
var finishedAt sql.NullTime
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{
|
||||
service.UsageCleanupStatusPending,
|
||||
service.UsageCleanupStatusRunning,
|
||||
staleRunningAfterSeconds,
|
||||
service.UsageCleanupStatusRunning,
|
||||
},
|
||||
&task.ID,
|
||||
&task.Status,
|
||||
&filtersJSON,
|
||||
&task.CreatedBy,
|
||||
&task.DeletedRows,
|
||||
&errMsg,
|
||||
&startedAt,
|
||||
&finishedAt,
|
||||
&task.CreatedAt,
|
||||
&task.UpdatedAt,
|
||||
); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
|
||||
return nil, fmt.Errorf("parse cleanup filters: %w", err)
|
||||
}
|
||||
if errMsg.Valid {
|
||||
task.ErrorMsg = &errMsg.String
|
||||
}
|
||||
if startedAt.Valid {
|
||||
task.StartedAt = &startedAt.Time
|
||||
}
|
||||
if finishedAt.Valid {
|
||||
task.FinishedAt = &finishedAt.Time
|
||||
}
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
|
||||
if r.client != nil {
|
||||
return r.getTaskStatusWithEnt(ctx, taskID)
|
||||
}
|
||||
var status string
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT status FROM usage_cleanup_tasks WHERE id = $1", []any{taskID}, &status); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
if r.client != nil {
|
||||
return r.updateTaskProgressWithEnt(ctx, taskID, deletedRows)
|
||||
}
|
||||
query := `
|
||||
UPDATE usage_cleanup_tasks
|
||||
SET deleted_rows = $1,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, deletedRows, taskID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
|
||||
if r.client != nil {
|
||||
return r.cancelTaskWithEnt(ctx, taskID, canceledBy)
|
||||
}
|
||||
query := `
|
||||
UPDATE usage_cleanup_tasks
|
||||
SET status = $1,
|
||||
canceled_by = $3,
|
||||
canceled_at = NOW(),
|
||||
finished_at = NOW(),
|
||||
error_message = NULL,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2
|
||||
AND status IN ($4, $5)
|
||||
RETURNING id
|
||||
`
|
||||
var id int64
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{
|
||||
service.UsageCleanupStatusCanceled,
|
||||
taskID,
|
||||
canceledBy,
|
||||
service.UsageCleanupStatusPending,
|
||||
service.UsageCleanupStatusRunning,
|
||||
}, &id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
if r.client != nil {
|
||||
return r.markTaskSucceededWithEnt(ctx, taskID, deletedRows)
|
||||
}
|
||||
query := `
|
||||
UPDATE usage_cleanup_tasks
|
||||
SET status = $1,
|
||||
deleted_rows = $2,
|
||||
finished_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = $3
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusSucceeded, deletedRows, taskID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
|
||||
if r.client != nil {
|
||||
return r.markTaskFailedWithEnt(ctx, taskID, deletedRows, errorMsg)
|
||||
}
|
||||
query := `
|
||||
UPDATE usage_cleanup_tasks
|
||||
SET status = $1,
|
||||
deleted_rows = $2,
|
||||
error_message = $3,
|
||||
finished_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = $4
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusFailed, deletedRows, errorMsg, taskID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
|
||||
if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
|
||||
return 0, fmt.Errorf("cleanup filters missing time range")
|
||||
}
|
||||
whereClause, args := buildUsageCleanupWhere(filters)
|
||||
if whereClause == "" {
|
||||
return 0, fmt.Errorf("cleanup filters missing time range")
|
||||
}
|
||||
args = append(args, limit)
|
||||
query := fmt.Sprintf(`
|
||||
WITH target AS (
|
||||
SELECT id
|
||||
FROM usage_logs
|
||||
WHERE %s
|
||||
ORDER BY created_at ASC, id ASC
|
||||
LIMIT $%d
|
||||
)
|
||||
DELETE FROM usage_logs
|
||||
WHERE id IN (SELECT id FROM target)
|
||||
RETURNING id
|
||||
`, whereClause, len(args))
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var deleted int64
|
||||
for rows.Next() {
|
||||
deleted++
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) {
|
||||
conditions := make([]string, 0, 8)
|
||||
args := make([]any, 0, 8)
|
||||
idx := 1
|
||||
if !filters.StartTime.IsZero() {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", idx))
|
||||
args = append(args, filters.StartTime)
|
||||
idx++
|
||||
}
|
||||
if !filters.EndTime.IsZero() {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", idx))
|
||||
args = append(args, filters.EndTime)
|
||||
idx++
|
||||
}
|
||||
if filters.UserID != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("user_id = $%d", idx))
|
||||
args = append(args, *filters.UserID)
|
||||
idx++
|
||||
}
|
||||
if filters.APIKeyID != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", idx))
|
||||
args = append(args, *filters.APIKeyID)
|
||||
idx++
|
||||
}
|
||||
if filters.AccountID != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("account_id = $%d", idx))
|
||||
args = append(args, *filters.AccountID)
|
||||
idx++
|
||||
}
|
||||
if filters.GroupID != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("group_id = $%d", idx))
|
||||
args = append(args, *filters.GroupID)
|
||||
idx++
|
||||
}
|
||||
if filters.Model != nil {
|
||||
model := strings.TrimSpace(*filters.Model)
|
||||
if model != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("model = $%d", idx))
|
||||
args = append(args, model)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
if filters.Stream != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
|
||||
args = append(args, *filters.Stream)
|
||||
idx++
|
||||
}
|
||||
if filters.BillingType != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx))
|
||||
args = append(args, *filters.BillingType)
|
||||
}
|
||||
return strings.Join(conditions, " AND "), args
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) createTaskWithEnt(ctx context.Context, task *service.UsageCleanupTask) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
filtersJSON, err := json.Marshal(task.Filters)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal cleanup filters: %w", err)
|
||||
}
|
||||
created, err := client.UsageCleanupTask.
|
||||
Create().
|
||||
SetStatus(task.Status).
|
||||
SetFilters(json.RawMessage(filtersJSON)).
|
||||
SetCreatedBy(task.CreatedBy).
|
||||
SetDeletedRows(task.DeletedRows).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
task.ID = created.ID
|
||||
task.CreatedAt = created.CreatedAt
|
||||
task.UpdatedAt = created.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) createTaskWithSQL(ctx context.Context, task *service.UsageCleanupTask) error {
|
||||
filtersJSON, err := json.Marshal(task.Filters)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal cleanup filters: %w", err)
|
||||
}
|
||||
query := `
|
||||
INSERT INTO usage_cleanup_tasks (
|
||||
status,
|
||||
filters,
|
||||
created_by,
|
||||
deleted_rows
|
||||
) VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) listTasksWithEnt(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
query := client.UsageCleanupTask.Query()
|
||||
total, err := query.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if total == 0 {
|
||||
return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
|
||||
}
|
||||
rows, err := query.
|
||||
Order(dbent.Desc(dbusagecleanuptask.FieldCreatedAt), dbent.Desc(dbusagecleanuptask.FieldID)).
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tasks := make([]service.UsageCleanupTask, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
task, err := usageCleanupTaskFromEnt(row)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
return tasks, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) getTaskStatusWithEnt(ctx context.Context, taskID int64) (string, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
task, err := client.UsageCleanupTask.Query().
|
||||
Where(dbusagecleanuptask.IDEQ(taskID)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return task.Status, nil
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) updateTaskProgressWithEnt(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
now := time.Now()
|
||||
_, err := client.UsageCleanupTask.Update().
|
||||
Where(dbusagecleanuptask.IDEQ(taskID)).
|
||||
SetDeletedRows(deletedRows).
|
||||
SetUpdatedAt(now).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) cancelTaskWithEnt(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
now := time.Now()
|
||||
affected, err := client.UsageCleanupTask.Update().
|
||||
Where(
|
||||
dbusagecleanuptask.IDEQ(taskID),
|
||||
dbusagecleanuptask.StatusIn(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning),
|
||||
).
|
||||
SetStatus(service.UsageCleanupStatusCanceled).
|
||||
SetCanceledBy(canceledBy).
|
||||
SetCanceledAt(now).
|
||||
SetFinishedAt(now).
|
||||
ClearErrorMessage().
|
||||
SetUpdatedAt(now).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) markTaskSucceededWithEnt(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
now := time.Now()
|
||||
_, err := client.UsageCleanupTask.Update().
|
||||
Where(dbusagecleanuptask.IDEQ(taskID)).
|
||||
SetStatus(service.UsageCleanupStatusSucceeded).
|
||||
SetDeletedRows(deletedRows).
|
||||
SetFinishedAt(now).
|
||||
SetUpdatedAt(now).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *usageCleanupRepository) markTaskFailedWithEnt(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
now := time.Now()
|
||||
_, err := client.UsageCleanupTask.Update().
|
||||
Where(dbusagecleanuptask.IDEQ(taskID)).
|
||||
SetStatus(service.UsageCleanupStatusFailed).
|
||||
SetDeletedRows(deletedRows).
|
||||
SetErrorMessage(errorMsg).
|
||||
SetFinishedAt(now).
|
||||
SetUpdatedAt(now).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func usageCleanupTaskFromEnt(row *dbent.UsageCleanupTask) (service.UsageCleanupTask, error) {
|
||||
task := service.UsageCleanupTask{
|
||||
ID: row.ID,
|
||||
Status: row.Status,
|
||||
CreatedBy: row.CreatedBy,
|
||||
DeletedRows: row.DeletedRows,
|
||||
CreatedAt: row.CreatedAt,
|
||||
UpdatedAt: row.UpdatedAt,
|
||||
}
|
||||
if len(row.Filters) > 0 {
|
||||
if err := json.Unmarshal(row.Filters, &task.Filters); err != nil {
|
||||
return service.UsageCleanupTask{}, fmt.Errorf("parse cleanup filters: %w", err)
|
||||
}
|
||||
}
|
||||
if row.ErrorMessage != nil {
|
||||
task.ErrorMsg = row.ErrorMessage
|
||||
}
|
||||
if row.CanceledBy != nil {
|
||||
task.CanceledBy = row.CanceledBy
|
||||
}
|
||||
if row.CanceledAt != nil {
|
||||
task.CanceledAt = row.CanceledAt
|
||||
}
|
||||
if row.StartedAt != nil {
|
||||
task.StartedAt = row.StartedAt
|
||||
}
|
||||
if row.FinishedAt != nil {
|
||||
task.FinishedAt = row.FinishedAt
|
||||
}
|
||||
return task, nil
|
||||
}
|
||||
251
backend/internal/repository/usage_cleanup_repo_ent_test.go
Normal file
251
backend/internal/repository/usage_cleanup_repo_ent_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
dbusagecleanuptask "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func newUsageCleanupEntRepo(t *testing.T) (*usageCleanupRepository, *dbent.Client) {
|
||||
t.Helper()
|
||||
db, err := sql.Open("sqlite", "file:usage_cleanup?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
repo := &usageCleanupRepository{client: client, sql: db}
|
||||
return repo, client
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryEntCreateAndList(t *testing.T) {
|
||||
repo, _ := newUsageCleanupEntRepo(t)
|
||||
|
||||
start := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
task := &service.UsageCleanupTask{
|
||||
Status: service.UsageCleanupStatusPending,
|
||||
Filters: service.UsageCleanupFilters{StartTime: start, EndTime: end},
|
||||
CreatedBy: 9,
|
||||
}
|
||||
require.NoError(t, repo.CreateTask(context.Background(), task))
|
||||
require.NotZero(t, task.ID)
|
||||
|
||||
task2 := &service.UsageCleanupTask{
|
||||
Status: service.UsageCleanupStatusRunning,
|
||||
Filters: service.UsageCleanupFilters{StartTime: start.Add(-24 * time.Hour), EndTime: end.Add(-24 * time.Hour)},
|
||||
CreatedBy: 10,
|
||||
}
|
||||
require.NoError(t, repo.CreateTask(context.Background(), task2))
|
||||
|
||||
tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tasks, 2)
|
||||
require.Equal(t, int64(2), result.Total)
|
||||
require.Greater(t, tasks[0].ID, tasks[1].ID)
|
||||
require.Equal(t, start, tasks[1].Filters.StartTime)
|
||||
require.Equal(t, end, tasks[1].Filters.EndTime)
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryEntListEmpty(t *testing.T) {
|
||||
repo, _ := newUsageCleanupEntRepo(t)
|
||||
|
||||
tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, tasks)
|
||||
require.Equal(t, int64(0), result.Total)
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryEntGetStatusAndProgress(t *testing.T) {
|
||||
repo, client := newUsageCleanupEntRepo(t)
|
||||
|
||||
task := &service.UsageCleanupTask{
|
||||
Status: service.UsageCleanupStatusPending,
|
||||
Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
|
||||
CreatedBy: 3,
|
||||
}
|
||||
require.NoError(t, repo.CreateTask(context.Background(), task))
|
||||
|
||||
status, err := repo.GetTaskStatus(context.Background(), task.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, service.UsageCleanupStatusPending, status)
|
||||
|
||||
_, err = repo.GetTaskStatus(context.Background(), task.ID+99)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
require.NoError(t, repo.UpdateTaskProgress(context.Background(), task.ID, 42))
|
||||
loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), loaded.DeletedRows)
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryEntCancelAndFinish(t *testing.T) {
|
||||
repo, client := newUsageCleanupEntRepo(t)
|
||||
|
||||
task := &service.UsageCleanupTask{
|
||||
Status: service.UsageCleanupStatusPending,
|
||||
Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
|
||||
CreatedBy: 5,
|
||||
}
|
||||
require.NoError(t, repo.CreateTask(context.Background(), task))
|
||||
|
||||
ok, err := repo.CancelTask(context.Background(), task.ID, 7)
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
|
||||
loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, service.UsageCleanupStatusCanceled, loaded.Status)
|
||||
require.NotNil(t, loaded.CanceledBy)
|
||||
require.NotNil(t, loaded.CanceledAt)
|
||||
require.NotNil(t, loaded.FinishedAt)
|
||||
|
||||
loaded.Status = service.UsageCleanupStatusSucceeded
|
||||
_, err = client.UsageCleanupTask.Update().Where(dbusagecleanuptask.IDEQ(task.ID)).SetStatus(loaded.Status).Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ok, err = repo.CancelTask(context.Background(), task.ID, 7)
|
||||
require.NoError(t, err)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryEntCancelError(t *testing.T) {
|
||||
repo, client := newUsageCleanupEntRepo(t)
|
||||
|
||||
task := &service.UsageCleanupTask{
|
||||
Status: service.UsageCleanupStatusPending,
|
||||
Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
|
||||
CreatedBy: 5,
|
||||
}
|
||||
require.NoError(t, repo.CreateTask(context.Background(), task))
|
||||
|
||||
require.NoError(t, client.Close())
|
||||
_, err := repo.CancelTask(context.Background(), task.ID, 7)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryEntMarkResults(t *testing.T) {
|
||||
repo, client := newUsageCleanupEntRepo(t)
|
||||
|
||||
task := &service.UsageCleanupTask{
|
||||
Status: service.UsageCleanupStatusRunning,
|
||||
Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
|
||||
CreatedBy: 12,
|
||||
}
|
||||
require.NoError(t, repo.CreateTask(context.Background(), task))
|
||||
|
||||
require.NoError(t, repo.MarkTaskSucceeded(context.Background(), task.ID, 6))
|
||||
loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, service.UsageCleanupStatusSucceeded, loaded.Status)
|
||||
require.Equal(t, int64(6), loaded.DeletedRows)
|
||||
require.NotNil(t, loaded.FinishedAt)
|
||||
|
||||
task2 := &service.UsageCleanupTask{
|
||||
Status: service.UsageCleanupStatusRunning,
|
||||
Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
|
||||
CreatedBy: 12,
|
||||
}
|
||||
require.NoError(t, repo.CreateTask(context.Background(), task2))
|
||||
|
||||
require.NoError(t, repo.MarkTaskFailed(context.Background(), task2.ID, 4, "boom"))
|
||||
loaded2, err := client.UsageCleanupTask.Get(context.Background(), task2.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, service.UsageCleanupStatusFailed, loaded2.Status)
|
||||
require.Equal(t, "boom", *loaded2.ErrorMessage)
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryEntInvalidStatus(t *testing.T) {
|
||||
repo, _ := newUsageCleanupEntRepo(t)
|
||||
|
||||
task := &service.UsageCleanupTask{
|
||||
Status: "invalid",
|
||||
Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
|
||||
CreatedBy: 1,
|
||||
}
|
||||
require.Error(t, repo.CreateTask(context.Background(), task))
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryEntListInvalidFilters(t *testing.T) {
|
||||
repo, client := newUsageCleanupEntRepo(t)
|
||||
|
||||
now := time.Now().UTC()
|
||||
driver, ok := client.Driver().(*entsql.Driver)
|
||||
require.True(t, ok)
|
||||
_, err := driver.DB().ExecContext(
|
||||
context.Background(),
|
||||
`INSERT INTO usage_cleanup_tasks (status, filters, created_by, deleted_rows, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
service.UsageCleanupStatusPending,
|
||||
[]byte("invalid-json"),
|
||||
int64(1),
|
||||
int64(0),
|
||||
now,
|
||||
now,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUsageCleanupTaskFromEntFull(t *testing.T) {
|
||||
start := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
errMsg := "failed"
|
||||
canceledBy := int64(2)
|
||||
canceledAt := start.Add(time.Minute)
|
||||
startedAt := start.Add(2 * time.Minute)
|
||||
finishedAt := start.Add(3 * time.Minute)
|
||||
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
|
||||
filtersJSON, err := json.Marshal(filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
task, err := usageCleanupTaskFromEnt(&dbent.UsageCleanupTask{
|
||||
ID: 10,
|
||||
Status: service.UsageCleanupStatusFailed,
|
||||
Filters: filtersJSON,
|
||||
CreatedBy: 11,
|
||||
DeletedRows: 7,
|
||||
ErrorMessage: &errMsg,
|
||||
CanceledBy: &canceledBy,
|
||||
CanceledAt: &canceledAt,
|
||||
StartedAt: &startedAt,
|
||||
FinishedAt: &finishedAt,
|
||||
CreatedAt: start,
|
||||
UpdatedAt: end,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), task.ID)
|
||||
require.Equal(t, service.UsageCleanupStatusFailed, task.Status)
|
||||
require.NotNil(t, task.ErrorMsg)
|
||||
require.NotNil(t, task.CanceledBy)
|
||||
require.NotNil(t, task.CanceledAt)
|
||||
require.NotNil(t, task.StartedAt)
|
||||
require.NotNil(t, task.FinishedAt)
|
||||
}
|
||||
|
||||
func TestUsageCleanupTaskFromEntInvalidFilters(t *testing.T) {
|
||||
task, err := usageCleanupTaskFromEnt(&dbent.UsageCleanupTask{
|
||||
Filters: json.RawMessage("invalid-json"),
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Empty(t, task)
|
||||
}
|
||||
482
backend/internal/repository/usage_cleanup_repo_test.go
Normal file
482
backend/internal/repository/usage_cleanup_repo_test.go
Normal file
@@ -0,0 +1,482 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newSQLMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
return db, mock
|
||||
}
|
||||
|
||||
func TestNewUsageCleanupRepository(t *testing.T) {
|
||||
db, _ := newSQLMock(t)
|
||||
repo := NewUsageCleanupRepository(nil, db)
|
||||
require.NotNil(t, repo)
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryCreateTask(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
task := &service.UsageCleanupTask{
|
||||
Status: service.UsageCleanupStatusPending,
|
||||
Filters: service.UsageCleanupFilters{StartTime: start, EndTime: end},
|
||||
CreatedBy: 12,
|
||||
}
|
||||
now := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
|
||||
WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at"}).AddRow(int64(1), now, now))
|
||||
|
||||
err := repo.CreateTask(context.Background(), task)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), task.ID)
|
||||
require.Equal(t, now, task.CreatedAt)
|
||||
require.Equal(t, now, task.UpdatedAt)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryCreateTaskNil(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
err := repo.CreateTask(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryCreateTaskQueryError(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
task := &service.UsageCleanupTask{
|
||||
Status: service.UsageCleanupStatusPending,
|
||||
Filters: service.UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(time.Hour)},
|
||||
CreatedBy: 1,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
|
||||
WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
err := repo.CreateTask(context.Background(), task)
|
||||
require.Error(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryListTasksEmpty(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
|
||||
|
||||
tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, tasks)
|
||||
require.Equal(t, int64(0), result.Total)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryListTasks(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(2 * time.Hour)
|
||||
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
|
||||
filtersJSON, err := json.Marshal(filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
createdAt := time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC)
|
||||
updatedAt := createdAt.Add(time.Minute)
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
|
||||
"canceled_by", "canceled_at",
|
||||
"started_at", "finished_at", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
int64(1),
|
||||
service.UsageCleanupStatusSucceeded,
|
||||
filtersJSON,
|
||||
int64(2),
|
||||
int64(9),
|
||||
"error",
|
||||
nil,
|
||||
nil,
|
||||
start,
|
||||
end,
|
||||
createdAt,
|
||||
updatedAt,
|
||||
)
|
||||
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
|
||||
mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
|
||||
WithArgs(20, 0).
|
||||
WillReturnRows(rows)
|
||||
|
||||
tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tasks, 1)
|
||||
require.Equal(t, int64(1), tasks[0].ID)
|
||||
require.Equal(t, service.UsageCleanupStatusSucceeded, tasks[0].Status)
|
||||
require.Equal(t, int64(2), tasks[0].CreatedBy)
|
||||
require.Equal(t, int64(9), tasks[0].DeletedRows)
|
||||
require.NotNil(t, tasks[0].ErrorMsg)
|
||||
require.Equal(t, "error", *tasks[0].ErrorMsg)
|
||||
require.NotNil(t, tasks[0].StartedAt)
|
||||
require.NotNil(t, tasks[0].FinishedAt)
|
||||
require.Equal(t, int64(1), result.Total)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryListTasksQueryError(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(2)))
|
||||
mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
|
||||
WithArgs(20, 0).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
_, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
|
||||
require.Error(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryListTasksInvalidFilters(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
|
||||
"canceled_by", "canceled_at",
|
||||
"started_at", "finished_at", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
int64(1),
|
||||
service.UsageCleanupStatusSucceeded,
|
||||
[]byte("not-json"),
|
||||
int64(2),
|
||||
int64(9),
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
time.Now().UTC(),
|
||||
time.Now().UTC(),
|
||||
)
|
||||
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
|
||||
mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
|
||||
WithArgs(20, 0).
|
||||
WillReturnRows(rows)
|
||||
|
||||
_, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
|
||||
require.Error(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryClaimNextPendingTaskNone(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
|
||||
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
|
||||
"started_at", "finished_at", "created_at", "updated_at",
|
||||
}))
|
||||
|
||||
task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, task)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryClaimNextPendingTask(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
|
||||
filtersJSON, err := json.Marshal(filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
|
||||
"started_at", "finished_at", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
int64(4),
|
||||
service.UsageCleanupStatusRunning,
|
||||
filtersJSON,
|
||||
int64(7),
|
||||
int64(0),
|
||||
nil,
|
||||
start,
|
||||
nil,
|
||||
start,
|
||||
start,
|
||||
)
|
||||
|
||||
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
|
||||
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
|
||||
WillReturnRows(rows)
|
||||
|
||||
task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, task)
|
||||
require.Equal(t, int64(4), task.ID)
|
||||
require.Equal(t, service.UsageCleanupStatusRunning, task.Status)
|
||||
require.Equal(t, int64(7), task.CreatedBy)
|
||||
require.NotNil(t, task.StartedAt)
|
||||
require.Nil(t, task.ErrorMsg)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryClaimNextPendingTaskError(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
|
||||
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
_, err := repo.ClaimNextPendingTask(context.Background(), 1800)
|
||||
require.Error(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryClaimNextPendingTaskInvalidFilters(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
|
||||
"started_at", "finished_at", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
int64(4),
|
||||
service.UsageCleanupStatusRunning,
|
||||
[]byte("invalid"),
|
||||
int64(7),
|
||||
int64(0),
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
time.Now().UTC(),
|
||||
time.Now().UTC(),
|
||||
)
|
||||
|
||||
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
|
||||
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
|
||||
WillReturnRows(rows)
|
||||
|
||||
_, err := repo.ClaimNextPendingTask(context.Background(), 1800)
|
||||
require.Error(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryMarkTaskSucceeded(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
mock.ExpectExec("UPDATE usage_cleanup_tasks").
|
||||
WithArgs(service.UsageCleanupStatusSucceeded, int64(12), int64(9)).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
err := repo.MarkTaskSucceeded(context.Background(), 9, 12)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryMarkTaskFailed(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
mock.ExpectExec("UPDATE usage_cleanup_tasks").
|
||||
WithArgs(service.UsageCleanupStatusFailed, int64(4), "boom", int64(2)).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
err := repo.MarkTaskFailed(context.Background(), 2, 4, "boom")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryGetTaskStatus(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks").
|
||||
WithArgs(int64(9)).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(service.UsageCleanupStatusPending))
|
||||
|
||||
status, err := repo.GetTaskStatus(context.Background(), 9)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, service.UsageCleanupStatusPending, status)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryGetTaskStatusQueryError(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks").
|
||||
WithArgs(int64(9)).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
_, err := repo.GetTaskStatus(context.Background(), 9)
|
||||
require.Error(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryUpdateTaskProgress(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
mock.ExpectExec("UPDATE usage_cleanup_tasks").
|
||||
WithArgs(int64(123), int64(8)).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
err := repo.UpdateTaskProgress(context.Background(), 8, 123)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryCancelTask(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
|
||||
WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(6)))
|
||||
|
||||
ok, err := repo.CancelTask(context.Background(), 6, 9)
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryCancelTaskNoRows(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
|
||||
WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
|
||||
ok, err := repo.CancelTask(context.Background(), 6, 9)
|
||||
require.NoError(t, err)
|
||||
require.False(t, ok)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryDeleteUsageLogsBatchMissingRange(t *testing.T) {
|
||||
db, _ := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
_, err := repo.DeleteUsageLogsBatch(context.Background(), service.UsageCleanupFilters{}, 10)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryDeleteUsageLogsBatch(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
userID := int64(3)
|
||||
model := " gpt-4 "
|
||||
filters := service.UsageCleanupFilters{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
UserID: &userID,
|
||||
Model: &model,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("DELETE FROM usage_logs").
|
||||
WithArgs(start, end, userID, "gpt-4", 2).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(1)).AddRow(int64(2)))
|
||||
|
||||
deleted, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), deleted)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryDeleteUsageLogsBatchQueryError(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
|
||||
|
||||
mock.ExpectQuery("DELETE FROM usage_logs").
|
||||
WithArgs(start, end, 5).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
_, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 5)
|
||||
require.Error(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestBuildUsageCleanupWhere(t *testing.T) {
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
userID := int64(1)
|
||||
apiKeyID := int64(2)
|
||||
accountID := int64(3)
|
||||
groupID := int64(4)
|
||||
model := " gpt-4 "
|
||||
stream := true
|
||||
billingType := int8(2)
|
||||
|
||||
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
UserID: &userID,
|
||||
APIKeyID: &apiKeyID,
|
||||
AccountID: &accountID,
|
||||
GroupID: &groupID,
|
||||
Model: &model,
|
||||
Stream: &stream,
|
||||
BillingType: &billingType,
|
||||
})
|
||||
|
||||
require.Equal(t, "created_at >= $1 AND created_at <= $2 AND user_id = $3 AND api_key_id = $4 AND account_id = $5 AND group_id = $6 AND model = $7 AND stream = $8 AND billing_type = $9", where)
|
||||
require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args)
|
||||
}
|
||||
|
||||
func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) {
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
model := " "
|
||||
|
||||
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Model: &model,
|
||||
})
|
||||
|
||||
require.Equal(t, "created_at >= $1 AND created_at <= $2", where)
|
||||
require.Equal(t, []any{start, end}, args)
|
||||
}
|
||||
@@ -1411,7 +1411,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
}
|
||||
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
@@ -1456,6 +1456,10 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||||
args = append(args, *stream)
|
||||
}
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
args = append(args, int16(*billingType))
|
||||
}
|
||||
query += " GROUP BY date ORDER BY date ASC"
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
@@ -1479,7 +1483,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
}
|
||||
|
||||
// GetModelStatsWithFilters returns model statistics with optional filters
|
||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
|
||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
@@ -1520,6 +1524,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||||
args = append(args, *stream)
|
||||
}
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
args = append(args, int16(*billingType))
|
||||
}
|
||||
query += " GROUP BY model ORDER BY total_tokens DESC"
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
@@ -1825,7 +1833,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
}
|
||||
}
|
||||
|
||||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
|
||||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil)
|
||||
if err != nil {
|
||||
models = []ModelStat{}
|
||||
}
|
||||
|
||||
@@ -944,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
endTime := base.Add(48 * time.Hour)
|
||||
|
||||
// Test with user filter
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil)
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
|
||||
s.Require().Len(trend, 2)
|
||||
|
||||
// Test with apiKey filter
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil)
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
|
||||
s.Require().Len(trend, 2)
|
||||
|
||||
// Test with both filters
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil)
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
@@ -971,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(3 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil)
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
@@ -1017,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
|
||||
// Test with user filter
|
||||
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil)
|
||||
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil)
|
||||
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
|
||||
s.Require().Len(stats, 2)
|
||||
|
||||
// Test with apiKey filter
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil)
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil)
|
||||
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
|
||||
s.Require().Len(stats, 2)
|
||||
|
||||
// Test with account filter
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil)
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil)
|
||||
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
|
||||
s.Require().Len(stats, 2)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
@@ -189,6 +190,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
dbuser.Or(
|
||||
dbuser.EmailContainsFold(filters.Search),
|
||||
dbuser.UsernameContainsFold(filters.Search),
|
||||
dbuser.NotesContainsFold(filters.Search),
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -466,3 +468,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
|
||||
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
update := client.User.UpdateOneID(userID)
|
||||
if encryptedSecret == nil {
|
||||
update = update.ClearTotpSecretEncrypted()
|
||||
} else {
|
||||
update = update.SetTotpSecretEncrypted(*encryptedSecret)
|
||||
}
|
||||
_, err := update.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableTotp 启用用户的 TOTP 双因素认证
|
||||
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.User.UpdateOneID(userID).
|
||||
SetTotpEnabled(true).
|
||||
SetTotpEnabledAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisableTotp 禁用用户的 TOTP 双因素认证
|
||||
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.User.UpdateOneID(userID).
|
||||
SetTotpEnabled(false).
|
||||
ClearTotpEnabledAt().
|
||||
ClearTotpSecretEncrypted().
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
||||
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
q := client.UserSubscription.Query()
|
||||
if userID != nil {
|
||||
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
|
||||
if groupID != nil {
|
||||
q = q.Where(usersubscription.GroupIDEQ(*groupID))
|
||||
}
|
||||
if status != "" {
|
||||
|
||||
// Status filtering with real-time expiration check
|
||||
now := time.Now()
|
||||
switch status {
|
||||
case service.SubscriptionStatusActive:
|
||||
// Active: status is active AND not yet expired
|
||||
q = q.Where(
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
usersubscription.ExpiresAtGT(now),
|
||||
)
|
||||
case service.SubscriptionStatusExpired:
|
||||
// Expired: status is expired OR (status is active but already expired)
|
||||
q = q.Where(
|
||||
usersubscription.Or(
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusExpired),
|
||||
usersubscription.And(
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
usersubscription.ExpiresAtLTE(now),
|
||||
),
|
||||
),
|
||||
)
|
||||
case "":
|
||||
// No filter
|
||||
default:
|
||||
// Other status (e.g., revoked)
|
||||
q = q.Where(usersubscription.StatusEQ(status))
|
||||
}
|
||||
|
||||
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Apply sorting
|
||||
q = q.WithUser().WithGroup().WithAssignedByUser()
|
||||
|
||||
// Determine sort field
|
||||
var field string
|
||||
switch sortBy {
|
||||
case "expires_at":
|
||||
field = usersubscription.FieldExpiresAt
|
||||
case "status":
|
||||
field = usersubscription.FieldStatus
|
||||
default:
|
||||
field = usersubscription.FieldCreatedAt
|
||||
}
|
||||
|
||||
// Determine sort order (default: desc)
|
||||
if sortOrder == "asc" && sortBy != "" {
|
||||
q = q.Order(dbent.Asc(field))
|
||||
} else {
|
||||
q = q.Order(dbent.Desc(field))
|
||||
}
|
||||
|
||||
subs, err := q.
|
||||
WithUser().
|
||||
WithGroup().
|
||||
WithAssignedByUser().
|
||||
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
All(ctx)
|
||||
|
||||
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
||||
group := s.mustCreateGroup("g-list")
|
||||
s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
|
||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(user1.ID, subs[0].UserID)
|
||||
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
||||
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
||||
s.mustCreateSubscription(user.ID, g2.ID, nil)
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(g1.ID, subs[0].GroupID)
|
||||
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
|
||||
|
||||
@@ -56,7 +56,10 @@ var ProviderSet = wire.NewSet(
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
NewAnnouncementRepository,
|
||||
NewAnnouncementReadRepository,
|
||||
NewUsageLogRepository,
|
||||
NewUsageCleanupRepository,
|
||||
NewDashboardAggregationRepository,
|
||||
NewSettingRepository,
|
||||
NewOpsRepository,
|
||||
@@ -81,6 +84,10 @@ var ProviderSet = wire.NewSet(
|
||||
NewSchedulerCache,
|
||||
NewSchedulerOutboxRepository,
|
||||
NewProxyLatencyCache,
|
||||
NewTotpCache,
|
||||
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -52,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
|
||||
"id": 1,
|
||||
"email": "alice@example.com",
|
||||
"username": "alice",
|
||||
"notes": "hello",
|
||||
"role": "user",
|
||||
"balance": 12.5,
|
||||
"concurrency": 5,
|
||||
@@ -132,6 +130,153 @@ func TestAPIContracts(t *testing.T) {
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "GET /api/v1/groups/available",
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
t.Helper()
|
||||
// 普通用户可见的分组列表不应包含内部字段(如 model_routing/account_count)。
|
||||
deps.groupRepo.SetActive([]service.Group{
|
||||
{
|
||||
ID: 10,
|
||||
Name: "Group One",
|
||||
Description: "desc",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.5,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
ModelRoutingEnabled: true,
|
||||
ModelRouting: map[string][]int64{
|
||||
"claude-3-*": []int64{101, 102},
|
||||
},
|
||||
AccountCount: 2,
|
||||
CreatedAt: deps.now,
|
||||
UpdatedAt: deps.now,
|
||||
},
|
||||
})
|
||||
deps.userSubRepo.SetActiveByUserID(1, nil)
|
||||
},
|
||||
method: http.MethodGet,
|
||||
path: "/api/v1/groups/available",
|
||||
wantStatus: http.StatusOK,
|
||||
wantJSON: `{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": [
|
||||
{
|
||||
"id": 10,
|
||||
"name": "Group One",
|
||||
"description": "desc",
|
||||
"platform": "anthropic",
|
||||
"rate_multiplier": 1.5,
|
||||
"is_exclusive": false,
|
||||
"status": "active",
|
||||
"subscription_type": "standard",
|
||||
"daily_limit_usd": null,
|
||||
"weekly_limit_usd": null,
|
||||
"monthly_limit_usd": null,
|
||||
"image_price_1k": null,
|
||||
"image_price_2k": null,
|
||||
"image_price_4k": null,
|
||||
"claude_code_only": false,
|
||||
"fallback_group_id": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "GET /api/v1/subscriptions",
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
t.Helper()
|
||||
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
|
||||
deps.userSubRepo.SetByUserID(1, []service.UserSubscription{
|
||||
{
|
||||
ID: 501,
|
||||
UserID: 1,
|
||||
GroupID: 10,
|
||||
StartsAt: deps.now,
|
||||
ExpiresAt: time.Date(2099, 1, 2, 3, 4, 5, 0, time.UTC), // 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
|
||||
Status: service.SubscriptionStatusActive,
|
||||
DailyUsageUSD: 1.23,
|
||||
WeeklyUsageUSD: 2.34,
|
||||
MonthlyUsageUSD: 3.45,
|
||||
AssignedBy: ptr(int64(999)),
|
||||
AssignedAt: deps.now,
|
||||
Notes: "admin-note",
|
||||
CreatedAt: deps.now,
|
||||
UpdatedAt: deps.now,
|
||||
},
|
||||
})
|
||||
},
|
||||
method: http.MethodGet,
|
||||
path: "/api/v1/subscriptions",
|
||||
wantStatus: http.StatusOK,
|
||||
wantJSON: `{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": [
|
||||
{
|
||||
"id": 501,
|
||||
"user_id": 1,
|
||||
"group_id": 10,
|
||||
"starts_at": "2025-01-02T03:04:05Z",
|
||||
"expires_at": "2099-01-02T03:04:05Z",
|
||||
"status": "active",
|
||||
"daily_window_start": null,
|
||||
"weekly_window_start": null,
|
||||
"monthly_window_start": null,
|
||||
"daily_usage_usd": 1.23,
|
||||
"weekly_usage_usd": 2.34,
|
||||
"monthly_usage_usd": 3.45,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "GET /api/v1/redeem/history",
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
t.Helper()
|
||||
// 普通用户兑换历史不应包含 notes 等内部字段。
|
||||
deps.redeemRepo.SetByUser(1, []service.RedeemCode{
|
||||
{
|
||||
ID: 900,
|
||||
Code: "CODE-123",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 1.25,
|
||||
Status: service.StatusUsed,
|
||||
UsedBy: ptr(int64(1)),
|
||||
UsedAt: ptr(deps.now),
|
||||
Notes: "internal-note",
|
||||
CreatedAt: deps.now,
|
||||
},
|
||||
})
|
||||
},
|
||||
method: http.MethodGet,
|
||||
path: "/api/v1/redeem/history",
|
||||
wantStatus: http.StatusOK,
|
||||
wantJSON: `{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": [
|
||||
{
|
||||
"id": 900,
|
||||
"code": "CODE-123",
|
||||
"type": "balance",
|
||||
"value": 1.25,
|
||||
"status": "used",
|
||||
"used_by": 1,
|
||||
"used_at": "2025-01-02T03:04:05Z",
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"group_id": null,
|
||||
"validity_days": 0
|
||||
}
|
||||
]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "GET /api/v1/usage/stats",
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
@@ -191,24 +336,25 @@ func TestAPIContracts(t *testing.T) {
|
||||
t.Helper()
|
||||
deps.usageRepo.SetUserLogs(1, []service.UsageLog{
|
||||
{
|
||||
ID: 1,
|
||||
UserID: 1,
|
||||
APIKeyID: 100,
|
||||
AccountID: 200,
|
||||
RequestID: "req_123",
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
CacheCreationTokens: 1,
|
||||
CacheReadTokens: 2,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
RateMultiplier: 1,
|
||||
BillingType: service.BillingTypeBalance,
|
||||
Stream: true,
|
||||
DurationMs: ptr(100),
|
||||
FirstTokenMs: ptr(50),
|
||||
CreatedAt: deps.now,
|
||||
ID: 1,
|
||||
UserID: 1,
|
||||
APIKeyID: 100,
|
||||
AccountID: 200,
|
||||
AccountRateMultiplier: ptr(0.5),
|
||||
RequestID: "req_123",
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
CacheCreationTokens: 1,
|
||||
CacheReadTokens: 2,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
RateMultiplier: 1,
|
||||
BillingType: service.BillingTypeBalance,
|
||||
Stream: true,
|
||||
DurationMs: ptr(100),
|
||||
FirstTokenMs: ptr(50),
|
||||
CreatedAt: deps.now,
|
||||
},
|
||||
})
|
||||
},
|
||||
@@ -239,10 +385,9 @@ func TestAPIContracts(t *testing.T) {
|
||||
"output_cost": 0,
|
||||
"cache_creation_cost": 0,
|
||||
"cache_read_cost": 0,
|
||||
"total_cost": 0.5,
|
||||
"total_cost": 0.5,
|
||||
"actual_cost": 0.5,
|
||||
"rate_multiplier": 1,
|
||||
"account_rate_multiplier": null,
|
||||
"billing_type": 0,
|
||||
"stream": true,
|
||||
"duration_ms": 100,
|
||||
@@ -267,6 +412,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
deps.settingRepo.SetAll(map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyEmailVerifyEnabled: "false",
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
|
||||
service.SettingKeySMTPHost: "smtp.example.com",
|
||||
service.SettingKeySMTPPort: "587",
|
||||
@@ -305,6 +451,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"data": {
|
||||
"registration_enabled": true,
|
||||
"email_verify_enabled": false,
|
||||
"promo_code_enabled": true,
|
||||
"password_reset_enabled": false,
|
||||
"totp_enabled": false,
|
||||
"totp_encryption_key_configured": false,
|
||||
"smtp_host": "smtp.example.com",
|
||||
"smtp_port": 587,
|
||||
"smtp_username": "user",
|
||||
@@ -338,45 +488,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"fallback_model_openai": "gpt-4o",
|
||||
"enable_identity_patch": true,
|
||||
"identity_patch_prompt": "",
|
||||
"home_content": ""
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "POST /api/v1/admin/accounts/lookup",
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
t.Helper()
|
||||
deps.accountRepo.lookupAccounts = []service.Account{
|
||||
{
|
||||
ID: 101,
|
||||
Name: "Alice Account",
|
||||
Platform: "antigravity",
|
||||
Credentials: map[string]any{
|
||||
"email": "alice@example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
method: http.MethodPost,
|
||||
path: "/api/v1/admin/accounts/lookup",
|
||||
body: `{"platform":"antigravity","emails":["Alice@Example.com","bob@example.com"]}`,
|
||||
headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
wantStatus: http.StatusOK,
|
||||
wantJSON: `{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"matched": [
|
||||
{
|
||||
"email": "alice@example.com",
|
||||
"account_id": 101,
|
||||
"platform": "antigravity",
|
||||
"name": "Alice Account"
|
||||
}
|
||||
],
|
||||
"missing": ["bob@example.com"]
|
||||
"home_content": "",
|
||||
"hide_ccs_import_button": false,
|
||||
"purchase_subscription_enabled": false,
|
||||
"purchase_subscription_url": ""
|
||||
}
|
||||
}`,
|
||||
},
|
||||
@@ -424,9 +539,11 @@ type contractDeps struct {
|
||||
now time.Time
|
||||
router http.Handler
|
||||
apiKeyRepo *stubApiKeyRepo
|
||||
groupRepo *stubGroupRepo
|
||||
userSubRepo *stubUserSubscriptionRepo
|
||||
usageRepo *stubUsageLogRepo
|
||||
settingRepo *stubSettingRepo
|
||||
accountRepo *stubAccountRepo
|
||||
redeemRepo *stubRedeemCodeRepo
|
||||
}
|
||||
|
||||
func newContractDeps(t *testing.T) *contractDeps {
|
||||
@@ -454,11 +571,11 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
|
||||
apiKeyRepo := newStubApiKeyRepo(now)
|
||||
apiKeyCache := stubApiKeyCache{}
|
||||
groupRepo := stubGroupRepo{}
|
||||
userSubRepo := stubUserSubscriptionRepo{}
|
||||
groupRepo := &stubGroupRepo{}
|
||||
userSubRepo := &stubUserSubscriptionRepo{}
|
||||
accountRepo := stubAccountRepo{}
|
||||
proxyRepo := stubProxyRepo{}
|
||||
redeemRepo := stubRedeemCodeRepo{}
|
||||
redeemRepo := &stubRedeemCodeRepo{}
|
||||
|
||||
cfg := &config.Config{
|
||||
Default: config.DefaultConfig{
|
||||
@@ -473,15 +590,21 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
usageRepo := newStubUsageLogRepo()
|
||||
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
|
||||
|
||||
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
|
||||
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
||||
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
jwtAuth := func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||
@@ -512,25 +635,35 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
v1Keys.Use(jwtAuth)
|
||||
v1Keys.GET("/keys", apiKeyHandler.List)
|
||||
v1Keys.POST("/keys", apiKeyHandler.Create)
|
||||
v1Keys.GET("/groups/available", apiKeyHandler.GetAvailableGroups)
|
||||
|
||||
v1Usage := v1.Group("")
|
||||
v1Usage.Use(jwtAuth)
|
||||
v1Usage.GET("/usage", usageHandler.List)
|
||||
v1Usage.GET("/usage/stats", usageHandler.Stats)
|
||||
|
||||
v1Subs := v1.Group("")
|
||||
v1Subs.Use(jwtAuth)
|
||||
v1Subs.GET("/subscriptions", subscriptionHandler.List)
|
||||
|
||||
v1Redeem := v1.Group("")
|
||||
v1Redeem.Use(jwtAuth)
|
||||
v1Redeem.GET("/redeem/history", redeemHandler.GetHistory)
|
||||
|
||||
v1Admin := v1.Group("/admin")
|
||||
v1Admin.Use(adminAuth)
|
||||
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
|
||||
v1Admin.POST("/accounts/bulk-update", adminAccountHandler.BulkUpdate)
|
||||
v1Admin.POST("/accounts/lookup", adminAccountHandler.Lookup)
|
||||
|
||||
return &contractDeps{
|
||||
now: now,
|
||||
router: r,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
usageRepo: usageRepo,
|
||||
settingRepo: settingRepo,
|
||||
accountRepo: &accountRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -626,6 +759,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubApiKeyCache struct{}
|
||||
|
||||
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
@@ -660,7 +805,21 @@ func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubGroupRepo struct{}
|
||||
func (stubApiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stubApiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubGroupRepo struct {
|
||||
active []service.Group
|
||||
}
|
||||
|
||||
func (r *stubGroupRepo) SetActive(groups []service.Group) {
|
||||
r.active = append([]service.Group(nil), groups...)
|
||||
}
|
||||
|
||||
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
|
||||
return errors.New("not implemented")
|
||||
@@ -694,12 +853,19 @@ func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.Pagi
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
|
||||
return append([]service.Group(nil), r.active...), nil
|
||||
}
|
||||
|
||||
func (stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
out := make([]service.Group, 0, len(r.active))
|
||||
for i := range r.active {
|
||||
g := r.active[i]
|
||||
if g.Platform == platform {
|
||||
out = append(out, g)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
@@ -715,8 +881,7 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
|
||||
}
|
||||
|
||||
type stubAccountRepo struct {
|
||||
bulkUpdateIDs []int64
|
||||
lookupAccounts []service.Account
|
||||
bulkUpdateIDs []int64
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) Create(ctx context.Context, account *service.Account) error {
|
||||
@@ -767,36 +932,6 @@ func (s *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) (
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ListByPlatformAndCredentialEmails(ctx context.Context, platform string, emails []string) ([]service.Account, error) {
|
||||
if len(s.lookupAccounts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
emailSet := make(map[string]struct{}, len(emails))
|
||||
for _, email := range emails {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
emailSet[normalized] = struct{}{}
|
||||
}
|
||||
var matches []service.Account
|
||||
for i := range s.lookupAccounts {
|
||||
account := &s.lookupAccounts[i]
|
||||
if account.Platform != platform {
|
||||
continue
|
||||
}
|
||||
accountEmail := strings.ToLower(strings.TrimSpace(account.GetCredential("email")))
|
||||
if accountEmail == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := emailSet[accountEmail]; !ok {
|
||||
continue
|
||||
}
|
||||
matches = append(matches, *account)
|
||||
}
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -948,7 +1083,16 @@ func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubRedeemCodeRepo struct{}
|
||||
type stubRedeemCodeRepo struct {
|
||||
byUser map[int64][]service.RedeemCode
|
||||
}
|
||||
|
||||
func (r *stubRedeemCodeRepo) SetByUser(userID int64, codes []service.RedeemCode) {
|
||||
if r.byUser == nil {
|
||||
r.byUser = make(map[int64][]service.RedeemCode)
|
||||
}
|
||||
r.byUser[userID] = append([]service.RedeemCode(nil), codes...)
|
||||
}
|
||||
|
||||
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
|
||||
return errors.New("not implemented")
|
||||
@@ -986,11 +1130,35 @@ func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
|
||||
if r.byUser == nil {
|
||||
return nil, nil
|
||||
}
|
||||
codes := r.byUser[userID]
|
||||
if limit > 0 && len(codes) > limit {
|
||||
codes = codes[:limit]
|
||||
}
|
||||
return append([]service.RedeemCode(nil), codes...), nil
|
||||
}
|
||||
|
||||
type stubUserSubscriptionRepo struct{}
|
||||
type stubUserSubscriptionRepo struct {
|
||||
byUser map[int64][]service.UserSubscription
|
||||
activeByUser map[int64][]service.UserSubscription
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) SetByUserID(userID int64, subs []service.UserSubscription) {
|
||||
if r.byUser == nil {
|
||||
r.byUser = make(map[int64][]service.UserSubscription)
|
||||
}
|
||||
r.byUser[userID] = append([]service.UserSubscription(nil), subs...)
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) SetActiveByUserID(userID int64, subs []service.UserSubscription) {
|
||||
if r.activeByUser == nil {
|
||||
r.activeByUser = make(map[int64][]service.UserSubscription)
|
||||
}
|
||||
r.activeByUser[userID] = append([]service.UserSubscription(nil), subs...)
|
||||
}
|
||||
|
||||
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
@@ -1010,16 +1178,22 @@ func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSub
|
||||
func (stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
if r.byUser == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return append([]service.UserSubscription(nil), r.byUser[userID]...), nil
|
||||
}
|
||||
func (stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
if r.activeByUser == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return append([]service.UserSubscription(nil), r.activeByUser[userID]...), nil
|
||||
}
|
||||
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
@@ -1319,11 +1493,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
|
||||
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
|
||||
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,9 @@ func RegisterAdminRoutes(
|
||||
// 账号管理
|
||||
registerAccountRoutes(admin, h)
|
||||
|
||||
// 公告管理
|
||||
registerAnnouncementRoutes(admin, h)
|
||||
|
||||
// OpenAI OAuth
|
||||
registerOpenAIOAuthRoutes(admin, h)
|
||||
|
||||
@@ -197,7 +200,6 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts := admin.Group("/accounts")
|
||||
{
|
||||
accounts.GET("", h.Admin.Account.List)
|
||||
accounts.POST("/lookup", h.Admin.Account.Lookup)
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||
@@ -230,6 +232,18 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerAnnouncementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
announcements := admin.Group("/announcements")
|
||||
{
|
||||
announcements.GET("", h.Admin.Announcement.List)
|
||||
announcements.POST("", h.Admin.Announcement.Create)
|
||||
announcements.GET("/:id", h.Admin.Announcement.GetByID)
|
||||
announcements.PUT("/:id", h.Admin.Announcement.Update)
|
||||
announcements.DELETE("/:id", h.Admin.Announcement.Delete)
|
||||
announcements.GET("/:id/read-status", h.Admin.Announcement.ListReadStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
openai := admin.Group("/openai")
|
||||
{
|
||||
@@ -355,6 +369,9 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
usage.GET("/stats", h.Admin.Usage.Stats)
|
||||
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
|
||||
usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
|
||||
usage.GET("/cleanup-tasks", h.Admin.Usage.ListCleanupTasks)
|
||||
usage.POST("/cleanup-tasks", h.Admin.Usage.CreateCleanupTask)
|
||||
usage.POST("/cleanup-tasks/:id/cancel", h.Admin.Usage.CancelCleanupTask)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,11 +26,20 @@ func RegisterAuthRoutes(
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/login/2fa", h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ValidatePromoCode)
|
||||
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
|
||||
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ForgotPassword)
|
||||
// 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||
auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ResetPassword)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,17 @@ func RegisterUserRoutes(
|
||||
user.GET("/profile", h.User.GetProfile)
|
||||
user.PUT("/password", h.User.ChangePassword)
|
||||
user.PUT("", h.User.UpdateProfile)
|
||||
|
||||
// TOTP 双因素认证
|
||||
totp := user.Group("/totp")
|
||||
{
|
||||
totp.GET("/status", h.Totp.GetStatus)
|
||||
totp.GET("/verification-method", h.Totp.GetVerificationMethod)
|
||||
totp.POST("/send-code", h.Totp.SendVerifyCode)
|
||||
totp.POST("/setup", h.Totp.InitiateSetup)
|
||||
totp.POST("/enable", h.Totp.Enable)
|
||||
totp.POST("/disable", h.Totp.Disable)
|
||||
}
|
||||
}
|
||||
|
||||
// API Key管理
|
||||
@@ -53,6 +64,13 @@ func RegisterUserRoutes(
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
|
||||
}
|
||||
|
||||
// 公告(用户可见)
|
||||
announcements := authenticated.Group("/announcements")
|
||||
{
|
||||
announcements.GET("", h.Announcement.List)
|
||||
announcements.POST("/:id/read", h.Announcement.MarkRead)
|
||||
}
|
||||
|
||||
// 卡密兑换
|
||||
redeem := authenticated.Group("/redeem")
|
||||
{
|
||||
|
||||
@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCredentialAsInt64 解析凭证中的 int64 字段
|
||||
// 用于读取 _token_version 等内部字段
|
||||
func (a *Account) GetCredentialAsInt64(key string) int64 {
|
||||
if a == nil || a.Credentials == nil {
|
||||
return 0
|
||||
}
|
||||
val, ok := a.Credentials[key]
|
||||
if !ok || val == nil {
|
||||
return 0
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case int64:
|
||||
return v
|
||||
case float64:
|
||||
return int64(v)
|
||||
case int:
|
||||
return int64(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return i
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (a *Account) IsTempUnschedulableEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
@@ -576,6 +605,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
|
||||
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
|
||||
}
|
||||
|
||||
// IsTLSFingerprintEnabled 检查是否启用 TLS 指纹伪装
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
// 启用后将模拟 Claude Code (Node.js) 客户端的 TLS 握手特征
|
||||
func (a *Account) IsTLSFingerprintEnabled() bool {
|
||||
// 仅支持 Anthropic OAuth/SetupToken 账号
|
||||
if !a.IsAnthropicOAuthOrSetupToken() {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["enable_tls_fingerprint"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
// 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID,
|
||||
// 使上游认为请求来自同一个会话
|
||||
func (a *Account) IsSessionIDMaskingEnabled() bool {
|
||||
if !a.IsAnthropicOAuthOrSetupToken() {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["session_id_masking_enabled"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
@@ -652,6 +719,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo
|
||||
return WindowCostNotSchedulable
|
||||
}
|
||||
|
||||
// GetCurrentWindowStartTime 获取当前有效的窗口开始时间
|
||||
// 逻辑:
|
||||
// 1. 如果窗口未过期(SessionWindowEnd 存在且在当前时间之后),使用记录的 SessionWindowStart
|
||||
// 2. 否则(窗口过期或未设置),使用新的预测窗口开始时间(从当前整点开始)
|
||||
func (a *Account) GetCurrentWindowStartTime() time.Time {
|
||||
now := time.Now()
|
||||
|
||||
// 窗口未过期,使用记录的窗口开始时间
|
||||
if a.SessionWindowStart != nil && a.SessionWindowEnd != nil && now.Before(*a.SessionWindowEnd) {
|
||||
return *a.SessionWindowStart
|
||||
}
|
||||
|
||||
// 窗口已过期或未设置,预测新的窗口开始时间(从当前整点开始)
|
||||
// 与 ratelimit_service.go 中 UpdateSessionWindow 的预测逻辑保持一致
|
||||
return time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
|
||||
}
|
||||
|
||||
// parseExtraFloat64 从 extra 字段解析 float64 值
|
||||
func parseExtraFloat64(value any) float64 {
|
||||
switch v := value.(type) {
|
||||
|
||||
@@ -33,7 +33,6 @@ type AccountRepository interface {
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListActive(ctx context.Context) ([]Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
ListByPlatformAndCredentialEmails(ctx context.Context, platform string, emails []string) ([]Account, error)
|
||||
|
||||
UpdateLastUsed(ctx context.Context, id int64) error
|
||||
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
||||
|
||||
@@ -87,10 +87,6 @@ func (s *accountRepoStub) ListByPlatform(ctx context.Context, platform string) (
|
||||
panic("unexpected ListByPlatform call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListByPlatformAndCredentialEmails(ctx context.Context, platform string, emails []string) ([]Account, error) {
|
||||
panic("unexpected ListByPlatformAndCredentialEmails call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
panic("unexpected UpdateLastUsed call")
|
||||
}
|
||||
|
||||
@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
|
||||
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
|
||||
|
||||
// Admin dashboard stats
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct {
|
||||
} `json:"seven_day_sonnet"`
|
||||
}
|
||||
|
||||
// ClaudeUsageFetchOptions 包含获取 Claude 用量数据所需的所有选项
|
||||
type ClaudeUsageFetchOptions struct {
|
||||
AccessToken string // OAuth access token
|
||||
ProxyURL string // 代理 URL(可选)
|
||||
AccountID int64 // 账号 ID(用于 TLS 指纹选择)
|
||||
EnableTLSFingerprint bool // 是否启用 TLS 指纹伪装
|
||||
Fingerprint *Fingerprint // 缓存的指纹信息(User-Agent 等)
|
||||
}
|
||||
|
||||
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
|
||||
type ClaudeUsageFetcher interface {
|
||||
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
|
||||
// FetchUsageWithOptions 使用完整选项获取用量数据,支持 TLS 指纹和自定义 User-Agent
|
||||
FetchUsageWithOptions(ctx context.Context, opts *ClaudeUsageFetchOptions) (*ClaudeUsageResponse, error)
|
||||
}
|
||||
|
||||
// AccountUsageService 账号使用量查询服务
|
||||
@@ -170,6 +181,7 @@ type AccountUsageService struct {
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
antigravityQuotaFetcher *AntigravityQuotaFetcher
|
||||
cache *UsageCache
|
||||
identityCache IdentityCache
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
@@ -180,6 +192,7 @@ func NewAccountUsageService(
|
||||
geminiQuotaService *GeminiQuotaService,
|
||||
antigravityQuotaFetcher *AntigravityQuotaFetcher,
|
||||
cache *UsageCache,
|
||||
identityCache IdentityCache,
|
||||
) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -188,6 +201,7 @@ func NewAccountUsageService(
|
||||
geminiQuotaService: geminiQuotaService,
|
||||
antigravityQuotaFetcher: antigravityQuotaFetcher,
|
||||
cache: cache,
|
||||
identityCache: identityCache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,7 +286,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
dayStart := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||
}
|
||||
@@ -294,7 +308,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
|
||||
minuteStart := now.Truncate(time.Minute)
|
||||
minuteResetAt := minuteStart.Add(time.Minute)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
|
||||
}
|
||||
@@ -369,12 +383,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
|
||||
|
||||
// 如果没有缓存,从数据库查询
|
||||
if windowStats == nil {
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := account.GetCurrentWindowStartTime()
|
||||
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
@@ -428,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo)
|
||||
// 如果账号开启了 TLS 指纹,则使用 TLS 指纹伪装
|
||||
// 如果有缓存的 Fingerprint,则使用缓存的 User-Agent 等信息
|
||||
func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
@@ -439,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
|
||||
// 构建完整的选项
|
||||
opts := &ClaudeUsageFetchOptions{
|
||||
AccessToken: accessToken,
|
||||
ProxyURL: proxyURL,
|
||||
AccountID: account.ID,
|
||||
EnableTLSFingerprint: account.IsTLSFingerprintEnabled(),
|
||||
}
|
||||
|
||||
// 尝试获取缓存的 Fingerprint(包含 User-Agent 等信息)
|
||||
if s.identityCache != nil {
|
||||
if fp, err := s.identityCache.GetFingerprint(ctx, account.ID); err == nil && fp != nil {
|
||||
opts.Fingerprint = fp
|
||||
}
|
||||
}
|
||||
|
||||
return s.usageFetcher.FetchUsageWithOptions(ctx, opts)
|
||||
}
|
||||
|
||||
// parseTime 尝试多种格式解析时间
|
||||
|
||||
@@ -40,7 +40,6 @@ type AdminService interface {
|
||||
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
||||
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
|
||||
DeleteAccount(ctx context.Context, id int64) error
|
||||
LookupAccountsByCredentialEmail(ctx context.Context, platform string, emails []string) ([]Account, error)
|
||||
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
||||
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
||||
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||
@@ -866,13 +865,6 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account,
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) LookupAccountsByCredentialEmail(ctx context.Context, platform string, emails []string) ([]Account, error) {
|
||||
if platform == "" || len(emails) == 0 {
|
||||
return []Account{}, nil
|
||||
}
|
||||
return s.accountRepo.ListByPlatformAndCredentialEmails(ctx, platform, emails)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
|
||||
if len(ids) == 0 {
|
||||
return []*Account{}, nil
|
||||
|
||||
@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
panic("unexpected UpdateTotpSecret call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected EnableTotp call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected DisableTotp call")
|
||||
}
|
||||
|
||||
type groupRepoStub struct {
|
||||
affectedUserIDs []int64
|
||||
deleteErr error
|
||||
|
||||
64
backend/internal/service/announcement.go
Normal file
64
backend/internal/service/announcement.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementStatusDraft = domain.AnnouncementStatusDraft
|
||||
AnnouncementStatusActive = domain.AnnouncementStatusActive
|
||||
AnnouncementStatusArchived = domain.AnnouncementStatusArchived
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription
|
||||
AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementOperatorIn = domain.AnnouncementOperatorIn
|
||||
AnnouncementOperatorGT = domain.AnnouncementOperatorGT
|
||||
AnnouncementOperatorGTE = domain.AnnouncementOperatorGTE
|
||||
AnnouncementOperatorLT = domain.AnnouncementOperatorLT
|
||||
AnnouncementOperatorLTE = domain.AnnouncementOperatorLTE
|
||||
AnnouncementOperatorEQ = domain.AnnouncementOperatorEQ
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
|
||||
ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
|
||||
)
|
||||
|
||||
type AnnouncementTargeting = domain.AnnouncementTargeting
|
||||
|
||||
type AnnouncementConditionGroup = domain.AnnouncementConditionGroup
|
||||
|
||||
type AnnouncementCondition = domain.AnnouncementCondition
|
||||
|
||||
type Announcement = domain.Announcement
|
||||
|
||||
type AnnouncementListFilters struct {
|
||||
Status string
|
||||
Search string
|
||||
}
|
||||
|
||||
type AnnouncementRepository interface {
|
||||
Create(ctx context.Context, a *Announcement) error
|
||||
GetByID(ctx context.Context, id int64) (*Announcement, error)
|
||||
Update(ctx context.Context, a *Announcement) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context, now time.Time) ([]Announcement, error)
|
||||
}
|
||||
|
||||
type AnnouncementReadRepository interface {
|
||||
MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error
|
||||
GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error)
|
||||
GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error)
|
||||
CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user