Compare commits

..

1 Commits

Author SHA1 Message Date
shaw
f36170dbfe chore: remove accidentally committed test binary 2026-01-11 10:35:09 +08:00
381 changed files with 2082 additions and 71470 deletions

5
.gitignore vendored
View File

@@ -83,8 +83,6 @@ temp/
*.log
*.bak
.cache/
.dev/
.serena/
# ===================
# 构建产物
@@ -128,5 +126,6 @@ backend/cmd/server/server
deploy/docker-compose.override.yml
.gocache/
vite.config.js
!docs/
docs/*
.serena/
!docs/dependency-security.md

View File

@@ -1,164 +0,0 @@
## 概述
全面增强运维监控系统Ops的错误日志管理和告警静默功能优化前端 UI 组件代码质量和用户体验。本次更新重构了核心服务层和数据访问层,提升系统可维护性和运维效率。
## 主要改动
### 1. 错误日志查询优化
**功能特性:**
- 新增 GetErrorLogByID 接口,支持按 ID 精确查询错误详情
- 优化错误日志过滤逻辑,支持多维度筛选(平台、阶段、来源、所有者等)
- 改进查询参数处理,简化代码结构
- 增强错误分类和标准化处理
- 支持错误解决状态追踪resolved 字段)
**技术实现:**
- `ops_handler.go` - 新增单条错误日志查询接口
- `ops_repo.go` - 优化数据查询和过滤条件构建
- `ops_models.go` - 扩展错误日志数据模型
- 前端 API 接口同步更新
### 2. 告警静默功能
**功能特性:**
- 支持按规则、平台、分组、区域等维度静默告警
- 可设置静默时长和原因说明
- 静默记录可追溯,记录创建人和创建时间
- 自动过期机制,避免永久静默
**技术实现:**
- `037_ops_alert_silences.sql` - 新增告警静默表
- `ops_alerts.go` - 告警静默逻辑实现
- `ops_alerts_handler.go` - 告警静默 API 接口
- `OpsAlertEventsCard.vue` - 前端告警静默操作界面
**数据库结构:**
| 字段 | 类型 | 说明 |
|------|------|------|
| rule_id | BIGINT | 告警规则 ID |
| platform | VARCHAR(64) | 平台标识 |
| group_id | BIGINT | 分组 ID可选 |
| region | VARCHAR(64) | 区域(可选) |
| until | TIMESTAMPTZ | 静默截止时间 |
| reason | TEXT | 静默原因 |
| created_by | BIGINT | 创建人 ID |
### 3. 错误分类标准化
**功能特性:**
- 统一错误阶段分类request|auth|routing|upstream|network|internal
- 规范错误归属分类client|provider|platform
- 标准化错误来源分类client_request|upstream_http|gateway
- 自动迁移历史数据到新分类体系
**技术实现:**
- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - 分类标准化迁移
- 自动映射历史遗留分类到新标准
- 自动解决已恢复的上游错误(客户端状态码 < 400
### 4. Gateway 服务集成
**功能特性:**
- 完善各 Gateway 服务的 Ops 集成
- 统一错误日志记录接口
- 增强上游错误追踪能力
**涉及服务:**
- `antigravity_gateway_service.go` - Antigravity 网关集成
- `gateway_service.go` - 通用网关集成
- `gemini_messages_compat_service.go` - Gemini 兼容层集成
- `openai_gateway_service.go` - OpenAI 网关集成
### 5. 前端 UI 优化
**代码重构:**
- 大幅简化错误详情模态框代码(从 828 行优化到 450 行)
- 优化错误日志表格组件,提升可读性
- 清理未使用的 i18n 翻译,减少冗余
- 统一组件代码风格和格式
- 优化骨架屏组件,更好匹配实际看板布局
**布局改进:**
- 修复模态框内容溢出和滚动问题
- 优化表格布局,使用 flex 布局确保正确显示
- 改进看板头部布局和交互
- 提升响应式体验
- 骨架屏支持全屏模式适配
**交互优化:**
- 优化告警事件卡片功能和展示
- 改进错误详情展示逻辑
- 增强请求详情模态框
- 完善运行时设置卡片
- 改进加载动画效果
### 6. 国际化完善
**文案补充:**
- 补充错误日志相关的英文翻译
- 添加告警静默功能的中英文文案
- 完善提示文本和错误信息
- 统一术语翻译标准
## 文件变更
**后端26 个文件):**
- `backend/internal/handler/admin/ops_alerts_handler.go` - 告警接口增强
- `backend/internal/handler/admin/ops_handler.go` - 错误日志接口优化
- `backend/internal/handler/ops_error_logger.go` - 错误记录器增强
- `backend/internal/repository/ops_repo.go` - 数据访问层重构
- `backend/internal/repository/ops_repo_alerts.go` - 告警数据访问增强
- `backend/internal/service/ops_*.go` - 核心服务层重构10 个文件)
- `backend/internal/service/*_gateway_service.go` - Gateway 集成4 个文件)
- `backend/internal/server/routes/admin.go` - 路由配置更新
- `backend/migrations/*.sql` - 数据库迁移2 个文件)
- 测试文件更新5 个文件)
**前端13 个文件):**
- `frontend/src/views/admin/ops/OpsDashboard.vue` - 看板主页优化
- `frontend/src/views/admin/ops/components/*.vue` - 组件重构10 个文件)
- `frontend/src/api/admin/ops.ts` - API 接口扩展
- `frontend/src/i18n/locales/*.ts` - 国际化文本2 个文件)
## 代码统计
- 44 个文件修改
- 3733 行新增
- 995 行删除
- 净增加 2738 行
## 核心改进
**可维护性提升:**
- 重构核心服务层,职责更清晰
- 简化前端组件代码,降低复杂度
- 统一代码风格和命名规范
- 清理冗余代码和未使用的翻译
- 标准化错误分类体系
**功能完善:**
- 告警静默功能,减少告警噪音
- 错误日志查询优化,提升运维效率
- Gateway 服务集成完善,统一监控能力
- 错误解决状态追踪,便于问题管理
**用户体验优化:**
- 修复多个 UI 布局问题
- 优化交互流程
- 完善国际化支持
- 提升响应式体验
- 改进加载状态展示
## 测试验证
- ✅ 错误日志查询和过滤功能
- ✅ 告警静默创建和自动过期
- ✅ 错误分类标准化迁移
- ✅ Gateway 服务错误日志记录
- ✅ 前端组件布局和交互
- ✅ 骨架屏全屏模式适配
- ✅ 国际化文本完整性
- ✅ API 接口功能正确性
- ✅ 数据库迁移执行成功

View File

@@ -57,13 +57,6 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
---
## OpenAI Responses 兼容注意事项
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id``tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
---
## 部署方式
### 方式一:脚本安装(推荐)

View File

@@ -1,2 +0,0 @@
.cache/
.DS_Store

View File

@@ -18,12 +18,6 @@ linters:
list-mode: original
files:
- "**/internal/service/**"
- "!**/internal/service/ops_aggregation_service.go"
- "!**/internal/service/ops_alert_evaluator_service.go"
- "!**/internal/service/ops_cleanup_service.go"
- "!**/internal/service/ops_metrics_collector.go"
- "!**/internal/service/ops_scheduled_report_service.go"
- "!**/internal/service/wire.go"
deny:
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
desc: "service must not import repository"

View File

@@ -8,7 +8,6 @@ import (
"errors"
"flag"
"log"
"log/slog"
"net/http"
"os"
"os/signal"
@@ -45,25 +44,7 @@ func init() {
}
}
// initLogger configures the default slog handler based on gin.Mode().
// In non-release mode, Debug level logs are enabled.
func initLogger() {
var level slog.Level
if gin.Mode() == gin.ReleaseMode {
level = slog.LevelInfo
} else {
level = slog.LevelDebug
}
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
slog.SetDefault(slog.New(handler))
}
func main() {
// Initialize slog logger based on gin mode
initLogger()
// Parse command line flags
setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
showVersion := flag.Bool("version", false, "Show version information")

View File

@@ -62,15 +62,8 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
func provideCleanup(
entClient *ent.Client,
rdb *redis.Client,
opsMetricsCollector *service.OpsMetricsCollector,
opsAggregation *service.OpsAggregationService,
opsAlertEvaluator *service.OpsAlertEvaluatorService,
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
@@ -88,48 +81,6 @@ func provideCleanup(
name string
fn func() error
}{
{"OpsScheduledReportService", func() error {
if opsScheduledReport != nil {
opsScheduledReport.Stop()
}
return nil
}},
{"OpsCleanupService", func() error {
if opsCleanup != nil {
opsCleanup.Stop()
}
return nil
}},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
}
return nil
}},
{"OpsAggregationService", func() error {
if opsAggregation != nil {
opsAggregation.Stop()
}
return nil
}},
{"OpsMetricsCollector", func() error {
if opsMetricsCollector != nil {
opsMetricsCollector.Stop()
}
return nil
}},
{"SchedulerSnapshotService", func() error {
if schedulerSnapshot != nil {
schedulerSnapshot.Stop()
}
return nil
}},
{"UsageCleanupService", func() error {
if usageCleanup != nil {
usageCleanup.Stop()
}
return nil
}},
{"TokenRefreshService", func() error {
tokenRefresh.Stop()
return nil

View File

@@ -55,40 +55,31 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingCache := repository.NewBillingCache(redisClient)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewAPIKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, db)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageService := service.NewUsageService(usageLogRepository, userRepository, client)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
timingWheelService, err := service.ProvideTimingWheelService()
if err != nil {
return nil, err
}
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
dashboardService := service.NewDashboardService(usageLogRepository)
dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
@@ -101,14 +92,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
@@ -118,8 +107,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -127,10 +115,19 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
schedulerCache := repository.NewSchedulerCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
@@ -139,46 +136,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityCache := repository.NewIdentityCache(redisClient)
identityService := service.NewIdentityService(identityCache)
timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService)
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, settingService, redisClient)
httpServer := server.ProvideHTTPServer(configConfig, engine)
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
v := provideCleanup(client, redisClient, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -203,15 +177,8 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
func provideCleanup(
entClient *ent.Client,
rdb *redis.Client,
opsMetricsCollector *service.OpsMetricsCollector,
opsAggregation *service.OpsAggregationService,
opsAlertEvaluator *service.OpsAlertEvaluatorService,
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
@@ -228,48 +195,6 @@ func provideCleanup(
name string
fn func() error
}{
{"OpsScheduledReportService", func() error {
if opsScheduledReport != nil {
opsScheduledReport.Stop()
}
return nil
}},
{"OpsCleanupService", func() error {
if opsCleanup != nil {
opsCleanup.Stop()
}
return nil
}},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
}
return nil
}},
{"OpsAggregationService", func() error {
if opsAggregation != nil {
opsAggregation.Stop()
}
return nil
}},
{"OpsMetricsCollector", func() error {
if opsMetricsCollector != nil {
opsMetricsCollector.Stop()
}
return nil
}},
{"SchedulerSnapshotService", func() error {
if schedulerSnapshot != nil {
schedulerSnapshot.Stop()
}
return nil
}},
{"UsageCleanupService", func() error {
if usageCleanup != nil {
usageCleanup.Stop()
}
return nil
}},
{"TokenRefreshService", func() error {
tokenRefresh.Stop()
return nil

View File

@@ -43,8 +43,6 @@ type Account struct {
Concurrency int `json:"concurrency,omitempty"`
// Priority holds the value of the "priority" field.
Priority int `json:"priority,omitempty"`
// RateMultiplier holds the value of the "rate_multiplier" field.
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// ErrorMessage holds the value of the "error_message" field.
@@ -137,8 +135,6 @@ func (*Account) scanValues(columns []string) ([]any, error) {
values[i] = new([]byte)
case account.FieldAutoPauseOnExpired, account.FieldSchedulable:
values[i] = new(sql.NullBool)
case account.FieldRateMultiplier:
values[i] = new(sql.NullFloat64)
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
values[i] = new(sql.NullInt64)
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
@@ -245,12 +241,6 @@ func (_m *Account) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Priority = int(value.Int64)
}
case account.FieldRateMultiplier:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field rate_multiplier", values[i])
} else if value.Valid {
_m.RateMultiplier = value.Float64
}
case account.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
@@ -430,9 +420,6 @@ func (_m *Account) String() string {
builder.WriteString("priority=")
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
builder.WriteString(", ")
builder.WriteString("rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")

View File

@@ -39,8 +39,6 @@ const (
FieldConcurrency = "concurrency"
// FieldPriority holds the string denoting the priority field in the database.
FieldPriority = "priority"
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
FieldRateMultiplier = "rate_multiplier"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldErrorMessage holds the string denoting the error_message field in the database.
@@ -118,7 +116,6 @@ var Columns = []string{
FieldProxyID,
FieldConcurrency,
FieldPriority,
FieldRateMultiplier,
FieldStatus,
FieldErrorMessage,
FieldLastUsedAt,
@@ -177,8 +174,6 @@ var (
DefaultConcurrency int
// DefaultPriority holds the default value on creation for the "priority" field.
DefaultPriority int
// DefaultRateMultiplier holds the default value on creation for the "rate_multiplier" field.
DefaultRateMultiplier float64
// DefaultStatus holds the default value on creation for the "status" field.
DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
@@ -249,11 +244,6 @@ func ByPriority(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPriority, opts...).ToFunc()
}
// ByRateMultiplier orders the results by the rate_multiplier field.
func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
}
// ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()

View File

@@ -105,11 +105,6 @@ func Priority(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldPriority, v))
}
// RateMultiplier applies equality check predicate on the "rate_multiplier" field. It's identical to RateMultiplierEQ.
func RateMultiplier(v float64) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
}
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldStatus, v))
@@ -680,46 +675,6 @@ func PriorityLTE(v int) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldPriority, v))
}
// RateMultiplierEQ applies the EQ predicate on the "rate_multiplier" field.
func RateMultiplierEQ(v float64) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
}
// RateMultiplierNEQ applies the NEQ predicate on the "rate_multiplier" field.
func RateMultiplierNEQ(v float64) predicate.Account {
return predicate.Account(sql.FieldNEQ(FieldRateMultiplier, v))
}
// RateMultiplierIn applies the In predicate on the "rate_multiplier" field.
func RateMultiplierIn(vs ...float64) predicate.Account {
return predicate.Account(sql.FieldIn(FieldRateMultiplier, vs...))
}
// RateMultiplierNotIn applies the NotIn predicate on the "rate_multiplier" field.
func RateMultiplierNotIn(vs ...float64) predicate.Account {
return predicate.Account(sql.FieldNotIn(FieldRateMultiplier, vs...))
}
// RateMultiplierGT applies the GT predicate on the "rate_multiplier" field.
func RateMultiplierGT(v float64) predicate.Account {
return predicate.Account(sql.FieldGT(FieldRateMultiplier, v))
}
// RateMultiplierGTE applies the GTE predicate on the "rate_multiplier" field.
func RateMultiplierGTE(v float64) predicate.Account {
return predicate.Account(sql.FieldGTE(FieldRateMultiplier, v))
}
// RateMultiplierLT applies the LT predicate on the "rate_multiplier" field.
func RateMultiplierLT(v float64) predicate.Account {
return predicate.Account(sql.FieldLT(FieldRateMultiplier, v))
}
// RateMultiplierLTE applies the LTE predicate on the "rate_multiplier" field.
func RateMultiplierLTE(v float64) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldRateMultiplier, v))
}
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldStatus, v))

View File

@@ -153,20 +153,6 @@ func (_c *AccountCreate) SetNillablePriority(v *int) *AccountCreate {
return _c
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (_c *AccountCreate) SetRateMultiplier(v float64) *AccountCreate {
_c.mutation.SetRateMultiplier(v)
return _c
}
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
func (_c *AccountCreate) SetNillableRateMultiplier(v *float64) *AccountCreate {
if v != nil {
_c.SetRateMultiplier(*v)
}
return _c
}
// SetStatus sets the "status" field.
func (_c *AccountCreate) SetStatus(v string) *AccountCreate {
_c.mutation.SetStatus(v)
@@ -443,10 +429,6 @@ func (_c *AccountCreate) defaults() error {
v := account.DefaultPriority
_c.mutation.SetPriority(v)
}
if _, ok := _c.mutation.RateMultiplier(); !ok {
v := account.DefaultRateMultiplier
_c.mutation.SetRateMultiplier(v)
}
if _, ok := _c.mutation.Status(); !ok {
v := account.DefaultStatus
_c.mutation.SetStatus(v)
@@ -506,9 +488,6 @@ func (_c *AccountCreate) check() error {
if _, ok := _c.mutation.Priority(); !ok {
return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "Account.priority"`)}
}
if _, ok := _c.mutation.RateMultiplier(); !ok {
return &ValidationError{Name: "rate_multiplier", err: errors.New(`ent: missing required field "Account.rate_multiplier"`)}
}
if _, ok := _c.mutation.Status(); !ok {
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Account.status"`)}
}
@@ -599,10 +578,6 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
_spec.SetField(account.FieldPriority, field.TypeInt, value)
_node.Priority = value
}
if value, ok := _c.mutation.RateMultiplier(); ok {
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
_node.RateMultiplier = value
}
if value, ok := _c.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value)
_node.Status = value
@@ -918,24 +893,6 @@ func (u *AccountUpsert) AddPriority(v int) *AccountUpsert {
return u
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (u *AccountUpsert) SetRateMultiplier(v float64) *AccountUpsert {
u.Set(account.FieldRateMultiplier, v)
return u
}
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
func (u *AccountUpsert) UpdateRateMultiplier() *AccountUpsert {
u.SetExcluded(account.FieldRateMultiplier)
return u
}
// AddRateMultiplier adds v to the "rate_multiplier" field.
func (u *AccountUpsert) AddRateMultiplier(v float64) *AccountUpsert {
u.Add(account.FieldRateMultiplier, v)
return u
}
// SetStatus sets the "status" field.
func (u *AccountUpsert) SetStatus(v string) *AccountUpsert {
u.Set(account.FieldStatus, v)
@@ -1368,27 +1325,6 @@ func (u *AccountUpsertOne) UpdatePriority() *AccountUpsertOne {
})
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (u *AccountUpsertOne) SetRateMultiplier(v float64) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.SetRateMultiplier(v)
})
}
// AddRateMultiplier adds v to the "rate_multiplier" field.
func (u *AccountUpsertOne) AddRateMultiplier(v float64) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.AddRateMultiplier(v)
})
}
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
func (u *AccountUpsertOne) UpdateRateMultiplier() *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.UpdateRateMultiplier()
})
}
// SetStatus sets the "status" field.
func (u *AccountUpsertOne) SetStatus(v string) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
@@ -2020,27 +1956,6 @@ func (u *AccountUpsertBulk) UpdatePriority() *AccountUpsertBulk {
})
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (u *AccountUpsertBulk) SetRateMultiplier(v float64) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.SetRateMultiplier(v)
})
}
// AddRateMultiplier adds v to the "rate_multiplier" field.
func (u *AccountUpsertBulk) AddRateMultiplier(v float64) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.AddRateMultiplier(v)
})
}
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
func (u *AccountUpsertBulk) UpdateRateMultiplier() *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.UpdateRateMultiplier()
})
}
// SetStatus sets the "status" field.
func (u *AccountUpsertBulk) SetStatus(v string) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {

View File

@@ -193,27 +193,6 @@ func (_u *AccountUpdate) AddPriority(v int) *AccountUpdate {
return _u
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (_u *AccountUpdate) SetRateMultiplier(v float64) *AccountUpdate {
_u.mutation.ResetRateMultiplier()
_u.mutation.SetRateMultiplier(v)
return _u
}
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
func (_u *AccountUpdate) SetNillableRateMultiplier(v *float64) *AccountUpdate {
if v != nil {
_u.SetRateMultiplier(*v)
}
return _u
}
// AddRateMultiplier adds value to the "rate_multiplier" field.
func (_u *AccountUpdate) AddRateMultiplier(v float64) *AccountUpdate {
_u.mutation.AddRateMultiplier(v)
return _u
}
// SetStatus sets the "status" field.
func (_u *AccountUpdate) SetStatus(v string) *AccountUpdate {
_u.mutation.SetStatus(v)
@@ -650,12 +629,6 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedPriority(); ok {
_spec.AddField(account.FieldPriority, field.TypeInt, value)
}
if value, ok := _u.mutation.RateMultiplier(); ok {
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value)
}
@@ -1032,27 +1005,6 @@ func (_u *AccountUpdateOne) AddPriority(v int) *AccountUpdateOne {
return _u
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (_u *AccountUpdateOne) SetRateMultiplier(v float64) *AccountUpdateOne {
_u.mutation.ResetRateMultiplier()
_u.mutation.SetRateMultiplier(v)
return _u
}
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
func (_u *AccountUpdateOne) SetNillableRateMultiplier(v *float64) *AccountUpdateOne {
if v != nil {
_u.SetRateMultiplier(*v)
}
return _u
}
// AddRateMultiplier adds value to the "rate_multiplier" field.
func (_u *AccountUpdateOne) AddRateMultiplier(v float64) *AccountUpdateOne {
_u.mutation.AddRateMultiplier(v)
return _u
}
// SetStatus sets the "status" field.
func (_u *AccountUpdateOne) SetStatus(v string) *AccountUpdateOne {
_u.mutation.SetStatus(v)
@@ -1519,12 +1471,6 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
if value, ok := _u.mutation.AddedPriority(); ok {
_spec.AddField(account.FieldPriority, field.TypeInt, value)
}
if value, ok := _u.mutation.RateMultiplier(); ok {
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value)
}

View File

@@ -24,7 +24,6 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -58,8 +57,6 @@ type Client struct {
RedeemCode *RedeemCodeClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
UsageCleanupTask *UsageCleanupTaskClient
// UsageLog is the client for interacting with the UsageLog builders.
UsageLog *UsageLogClient
// User is the client for interacting with the User builders.
@@ -92,7 +89,6 @@ func (c *Client) init() {
c.Proxy = NewProxyClient(c.config)
c.RedeemCode = NewRedeemCodeClient(c.config)
c.Setting = NewSettingClient(c.config)
c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config)
c.UsageLog = NewUsageLogClient(c.config)
c.User = NewUserClient(c.config)
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
@@ -200,7 +196,6 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg),
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
@@ -235,7 +230,6 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg),
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
@@ -272,9 +266,8 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
c.Proxy, c.RedeemCode, c.Setting, c.UsageLog, c.User, c.UserAllowedGroup,
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
} {
n.Use(hooks...)
}
@@ -285,9 +278,8 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
c.Proxy, c.RedeemCode, c.Setting, c.UsageLog, c.User, c.UserAllowedGroup,
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
} {
n.Intercept(interceptors...)
}
@@ -314,8 +306,6 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.RedeemCode.mutate(ctx, m)
case *SettingMutation:
return c.Setting.mutate(ctx, m)
case *UsageCleanupTaskMutation:
return c.UsageCleanupTask.mutate(ctx, m)
case *UsageLogMutation:
return c.UsageLog.mutate(ctx, m)
case *UserMutation:
@@ -1857,139 +1847,6 @@ func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value,
}
}
// UsageCleanupTaskClient is a client for the UsageCleanupTask schema.
type UsageCleanupTaskClient struct {
config
}
// NewUsageCleanupTaskClient returns a client for the UsageCleanupTask from the given config.
func NewUsageCleanupTaskClient(c config) *UsageCleanupTaskClient {
return &UsageCleanupTaskClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `usagecleanuptask.Hooks(f(g(h())))`.
func (c *UsageCleanupTaskClient) Use(hooks ...Hook) {
c.hooks.UsageCleanupTask = append(c.hooks.UsageCleanupTask, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `usagecleanuptask.Intercept(f(g(h())))`.
func (c *UsageCleanupTaskClient) Intercept(interceptors ...Interceptor) {
c.inters.UsageCleanupTask = append(c.inters.UsageCleanupTask, interceptors...)
}
// Create returns a builder for creating a UsageCleanupTask entity.
func (c *UsageCleanupTaskClient) Create() *UsageCleanupTaskCreate {
mutation := newUsageCleanupTaskMutation(c.config, OpCreate)
return &UsageCleanupTaskCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of UsageCleanupTask entities.
func (c *UsageCleanupTaskClient) CreateBulk(builders ...*UsageCleanupTaskCreate) *UsageCleanupTaskCreateBulk {
return &UsageCleanupTaskCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *UsageCleanupTaskClient) MapCreateBulk(slice any, setFunc func(*UsageCleanupTaskCreate, int)) *UsageCleanupTaskCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &UsageCleanupTaskCreateBulk{err: fmt.Errorf("calling to UsageCleanupTaskClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*UsageCleanupTaskCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &UsageCleanupTaskCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for UsageCleanupTask.
func (c *UsageCleanupTaskClient) Update() *UsageCleanupTaskUpdate {
mutation := newUsageCleanupTaskMutation(c.config, OpUpdate)
return &UsageCleanupTaskUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *UsageCleanupTaskClient) UpdateOne(_m *UsageCleanupTask) *UsageCleanupTaskUpdateOne {
mutation := newUsageCleanupTaskMutation(c.config, OpUpdateOne, withUsageCleanupTask(_m))
return &UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *UsageCleanupTaskClient) UpdateOneID(id int64) *UsageCleanupTaskUpdateOne {
mutation := newUsageCleanupTaskMutation(c.config, OpUpdateOne, withUsageCleanupTaskID(id))
return &UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for UsageCleanupTask.
func (c *UsageCleanupTaskClient) Delete() *UsageCleanupTaskDelete {
mutation := newUsageCleanupTaskMutation(c.config, OpDelete)
return &UsageCleanupTaskDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *UsageCleanupTaskClient) DeleteOne(_m *UsageCleanupTask) *UsageCleanupTaskDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *UsageCleanupTaskClient) DeleteOneID(id int64) *UsageCleanupTaskDeleteOne {
builder := c.Delete().Where(usagecleanuptask.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &UsageCleanupTaskDeleteOne{builder}
}
// Query returns a query builder for UsageCleanupTask.
func (c *UsageCleanupTaskClient) Query() *UsageCleanupTaskQuery {
return &UsageCleanupTaskQuery{
config: c.config,
ctx: &QueryContext{Type: TypeUsageCleanupTask},
inters: c.Interceptors(),
}
}
// Get returns a UsageCleanupTask entity by its id.
func (c *UsageCleanupTaskClient) Get(ctx context.Context, id int64) (*UsageCleanupTask, error) {
return c.Query().Where(usagecleanuptask.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *UsageCleanupTaskClient) GetX(ctx context.Context, id int64) *UsageCleanupTask {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// Hooks returns the client hooks.
func (c *UsageCleanupTaskClient) Hooks() []Hook {
return c.hooks.UsageCleanupTask
}
// Interceptors returns the client interceptors.
func (c *UsageCleanupTaskClient) Interceptors() []Interceptor {
return c.inters.UsageCleanupTask
}
func (c *UsageCleanupTaskClient) mutate(ctx context.Context, m *UsageCleanupTaskMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&UsageCleanupTaskCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&UsageCleanupTaskUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&UsageCleanupTaskDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown UsageCleanupTask mutation op: %q", m.Op())
}
}
// UsageLogClient is a client for the UsageLog schema.
type UsageLogClient struct {
config
@@ -3117,13 +2974,13 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
type (
hooks struct {
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
RedeemCode, Setting, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
RedeemCode, Setting, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeValue, UserSubscription []ent.Interceptor
}
)

View File

@@ -21,7 +21,6 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -97,7 +96,6 @@ func checkColumn(t, c string) error {
proxy.Table: proxy.ValidColumn,
redeemcode.Table: redeemcode.ValidColumn,
setting.Table: setting.ValidColumn,
usagecleanuptask.Table: usagecleanuptask.ValidColumn,
usagelog.Table: usagelog.ValidColumn,
user.Table: user.ValidColumn,
userallowedgroup.Table: userallowedgroup.ValidColumn,

View File

@@ -3,7 +3,6 @@
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
@@ -56,10 +55,6 @@ type Group struct {
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// 模型路由配置:模型模式 -> 优先账号ID列表
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
// 是否启用模型路由配置
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -166,9 +161,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case group.FieldModelRouting:
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled:
case group.FieldIsExclusive, group.FieldClaudeCodeOnly:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
@@ -322,20 +315,6 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.FallbackGroupID = new(int64)
*_m.FallbackGroupID = value.Int64
}
case group.FieldModelRouting:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field model_routing", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.ModelRouting); err != nil {
return fmt.Errorf("unmarshal field model_routing: %w", err)
}
}
case group.FieldModelRoutingEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field model_routing_enabled", values[i])
} else if value.Valid {
_m.ModelRoutingEnabled = value.Bool
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -486,12 +465,6 @@ func (_m *Group) String() string {
builder.WriteString("fallback_group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("model_routing=")
builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting))
builder.WriteString(", ")
builder.WriteString("model_routing_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -53,10 +53,6 @@ const (
FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
FieldFallbackGroupID = "fallback_group_id"
// FieldModelRouting holds the string denoting the model_routing field in the database.
FieldModelRouting = "model_routing"
// FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database.
FieldModelRoutingEnabled = "model_routing_enabled"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -151,8 +147,6 @@ var Columns = []string{
FieldImagePrice4k,
FieldClaudeCodeOnly,
FieldFallbackGroupID,
FieldModelRouting,
FieldModelRoutingEnabled,
}
var (
@@ -210,8 +204,6 @@ var (
DefaultDefaultValidityDays int
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
DefaultModelRoutingEnabled bool
)
// OrderOption defines the ordering options for the Group queries.
@@ -317,11 +309,6 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
}
// ByModelRoutingEnabled orders the results by the model_routing_enabled field.
func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -150,11 +150,6 @@ func FallbackGroupID(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
}
// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ.
func ModelRoutingEnabled(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -1070,26 +1065,6 @@ func FallbackGroupIDNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
}
// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field.
func ModelRoutingIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldModelRouting))
}
// ModelRoutingNotNil applies the NotNil predicate on the "model_routing" field.
func ModelRoutingNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldModelRouting))
}
// ModelRoutingEnabledEQ applies the EQ predicate on the "model_routing_enabled" field.
func ModelRoutingEnabledEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
}
// ModelRoutingEnabledNEQ applies the NEQ predicate on the "model_routing_enabled" field.
func ModelRoutingEnabledNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {

View File

@@ -286,26 +286,6 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
return _c
}
// SetModelRouting sets the "model_routing" field.
func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate {
_c.mutation.SetModelRouting(v)
return _c
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (_c *GroupCreate) SetModelRoutingEnabled(v bool) *GroupCreate {
_c.mutation.SetModelRoutingEnabled(v)
return _c
}
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
func (_c *GroupCreate) SetNillableModelRoutingEnabled(v *bool) *GroupCreate {
if v != nil {
_c.SetModelRoutingEnabled(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -475,10 +455,6 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v)
}
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
v := group.DefaultModelRoutingEnabled
_c.mutation.SetModelRoutingEnabled(v)
}
return nil
}
@@ -534,9 +510,6 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
}
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)}
}
return nil
}
@@ -640,14 +613,6 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
_node.FallbackGroupID = &value
}
if value, ok := _c.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
_node.ModelRouting = value
}
if value, ok := _c.mutation.ModelRoutingEnabled(); ok {
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
_node.ModelRoutingEnabled = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1128,36 +1093,6 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
return u
}
// SetModelRouting sets the "model_routing" field.
func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert {
u.Set(group.FieldModelRouting, v)
return u
}
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
func (u *GroupUpsert) UpdateModelRouting() *GroupUpsert {
u.SetExcluded(group.FieldModelRouting)
return u
}
// ClearModelRouting clears the value of the "model_routing" field.
func (u *GroupUpsert) ClearModelRouting() *GroupUpsert {
u.SetNull(group.FieldModelRouting)
return u
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (u *GroupUpsert) SetModelRoutingEnabled(v bool) *GroupUpsert {
u.Set(group.FieldModelRoutingEnabled, v)
return u
}
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
func (u *GroupUpsert) UpdateModelRoutingEnabled() *GroupUpsert {
u.SetExcluded(group.FieldModelRoutingEnabled)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1581,41 +1516,6 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
})
}
// SetModelRouting sets the "model_routing" field.
func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetModelRouting(v)
})
}
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateModelRouting() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelRouting()
})
}
// ClearModelRouting clears the value of the "model_routing" field.
func (u *GroupUpsertOne) ClearModelRouting() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearModelRouting()
})
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (u *GroupUpsertOne) SetModelRoutingEnabled(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetModelRoutingEnabled(v)
})
}
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateModelRoutingEnabled() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelRoutingEnabled()
})
}
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -2205,41 +2105,6 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
})
}
// SetModelRouting sets the "model_routing" field.
func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetModelRouting(v)
})
}
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateModelRouting() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelRouting()
})
}
// ClearModelRouting clears the value of the "model_routing" field.
func (u *GroupUpsertBulk) ClearModelRouting() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearModelRouting()
})
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (u *GroupUpsertBulk) SetModelRoutingEnabled(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetModelRoutingEnabled(v)
})
}
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateModelRoutingEnabled() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelRoutingEnabled()
})
}
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -395,32 +395,6 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
return _u
}
// SetModelRouting sets the "model_routing" field.
func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate {
_u.mutation.SetModelRouting(v)
return _u
}
// ClearModelRouting clears the value of the "model_routing" field.
func (_u *GroupUpdate) ClearModelRouting() *GroupUpdate {
_u.mutation.ClearModelRouting()
return _u
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (_u *GroupUpdate) SetModelRoutingEnabled(v bool) *GroupUpdate {
_u.mutation.SetModelRoutingEnabled(v)
return _u
}
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableModelRoutingEnabled(v *bool) *GroupUpdate {
if v != nil {
_u.SetModelRoutingEnabled(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -829,15 +803,6 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
if value, ok := _u.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
}
if _u.mutation.ModelRoutingCleared() {
_spec.ClearField(group.FieldModelRouting, field.TypeJSON)
}
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1513,32 +1478,6 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
return _u
}
// SetModelRouting sets the "model_routing" field.
func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne {
_u.mutation.SetModelRouting(v)
return _u
}
// ClearModelRouting clears the value of the "model_routing" field.
func (_u *GroupUpdateOne) ClearModelRouting() *GroupUpdateOne {
_u.mutation.ClearModelRouting()
return _u
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (_u *GroupUpdateOne) SetModelRoutingEnabled(v bool) *GroupUpdateOne {
_u.mutation.SetModelRoutingEnabled(v)
return _u
}
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableModelRoutingEnabled(v *bool) *GroupUpdateOne {
if v != nil {
_u.SetModelRoutingEnabled(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1977,15 +1916,6 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
if value, ok := _u.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
}
if _u.mutation.ModelRoutingCleared() {
_spec.ClearField(group.FieldModelRouting, field.TypeJSON)
}
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,

View File

@@ -117,18 +117,6 @@ func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, err
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m)
}
// The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary
// function as UsageCleanupTask mutator.
type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f UsageCleanupTaskFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.UsageCleanupTaskMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UsageCleanupTaskMutation", m)
}
// The UsageLogFunc type is an adapter to allow the use of ordinary
// function as UsageLog mutator.
type UsageLogFunc func(context.Context, *ent.UsageLogMutation) (ent.Value, error)

View File

@@ -18,7 +18,6 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -326,33 +325,6 @@ func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error {
return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q)
}
// The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary function as a Querier.
type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f UsageCleanupTaskFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.UsageCleanupTaskQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UsageCleanupTaskQuery", q)
}
// The TraverseUsageCleanupTask type is an adapter to allow the use of ordinary function as Traverser.
type TraverseUsageCleanupTask func(context.Context, *ent.UsageCleanupTaskQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseUsageCleanupTask) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseUsageCleanupTask) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.UsageCleanupTaskQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.UsageCleanupTaskQuery", q)
}
// The UsageLogFunc type is an adapter to allow the use of ordinary function as a Querier.
type UsageLogFunc func(context.Context, *ent.UsageLogQuery) (ent.Value, error)
@@ -536,8 +508,6 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil
case *ent.SettingQuery:
return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil
case *ent.UsageCleanupTaskQuery:
return &query[*ent.UsageCleanupTaskQuery, predicate.UsageCleanupTask, usagecleanuptask.OrderOption]{typ: ent.TypeUsageCleanupTask, tq: q}, nil
case *ent.UsageLogQuery:
return &query[*ent.UsageLogQuery, predicate.UsageLog, usagelog.OrderOption]{typ: ent.TypeUsageLog, tq: q}, nil
case *ent.UserQuery:

View File

@@ -79,7 +79,6 @@ var (
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "concurrency", Type: field.TypeInt, Default: 3},
{Name: "priority", Type: field.TypeInt, Default: 50},
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "error_message", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "last_used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
@@ -102,7 +101,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "accounts_proxies_proxy",
Columns: []*schema.Column{AccountsColumns[25]},
Columns: []*schema.Column{AccountsColumns[24]},
RefColumns: []*schema.Column{ProxiesColumns[0]},
OnDelete: schema.SetNull,
},
@@ -121,12 +120,12 @@ var (
{
Name: "account_status",
Unique: false,
Columns: []*schema.Column{AccountsColumns[13]},
Columns: []*schema.Column{AccountsColumns[12]},
},
{
Name: "account_proxy_id",
Unique: false,
Columns: []*schema.Column{AccountsColumns[25]},
Columns: []*schema.Column{AccountsColumns[24]},
},
{
Name: "account_priority",
@@ -136,27 +135,27 @@ var (
{
Name: "account_last_used_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[15]},
Columns: []*schema.Column{AccountsColumns[14]},
},
{
Name: "account_schedulable",
Unique: false,
Columns: []*schema.Column{AccountsColumns[18]},
Columns: []*schema.Column{AccountsColumns[17]},
},
{
Name: "account_rate_limited_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[19]},
Columns: []*schema.Column{AccountsColumns[18]},
},
{
Name: "account_rate_limit_reset_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[20]},
Columns: []*schema.Column{AccountsColumns[19]},
},
{
Name: "account_overload_until",
Unique: false,
Columns: []*schema.Column{AccountsColumns[21]},
Columns: []*schema.Column{AccountsColumns[20]},
},
{
Name: "account_deleted_at",
@@ -226,8 +225,6 @@ var (
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{
@@ -434,44 +431,6 @@ var (
Columns: SettingsColumns,
PrimaryKey: []*schema.Column{SettingsColumns[0]},
}
// UsageCleanupTasksColumns holds the columns for the "usage_cleanup_tasks" table.
UsageCleanupTasksColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "status", Type: field.TypeString, Size: 20},
{Name: "filters", Type: field.TypeJSON},
{Name: "created_by", Type: field.TypeInt64},
{Name: "deleted_rows", Type: field.TypeInt64, Default: 0},
{Name: "error_message", Type: field.TypeString, Nullable: true},
{Name: "canceled_by", Type: field.TypeInt64, Nullable: true},
{Name: "canceled_at", Type: field.TypeTime, Nullable: true},
{Name: "started_at", Type: field.TypeTime, Nullable: true},
{Name: "finished_at", Type: field.TypeTime, Nullable: true},
}
// UsageCleanupTasksTable holds the schema information for the "usage_cleanup_tasks" table.
UsageCleanupTasksTable = &schema.Table{
Name: "usage_cleanup_tasks",
Columns: UsageCleanupTasksColumns,
PrimaryKey: []*schema.Column{UsageCleanupTasksColumns[0]},
Indexes: []*schema.Index{
{
Name: "usagecleanuptask_status_created_at",
Unique: false,
Columns: []*schema.Column{UsageCleanupTasksColumns[3], UsageCleanupTasksColumns[1]},
},
{
Name: "usagecleanuptask_created_at",
Unique: false,
Columns: []*schema.Column{UsageCleanupTasksColumns[1]},
},
{
Name: "usagecleanuptask_canceled_at",
Unique: false,
Columns: []*schema.Column{UsageCleanupTasksColumns[9]},
},
},
}
// UsageLogsColumns holds the columns for the "usage_logs" table.
UsageLogsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -490,7 +449,6 @@ var (
{Name: "total_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "actual_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "account_rate_multiplier", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "billing_type", Type: field.TypeInt8, Default: 0},
{Name: "stream", Type: field.TypeBool, Default: false},
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
@@ -514,31 +472,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[25]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[26]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[27]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -547,32 +505,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[28]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[25]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[26]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[27]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[29]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[25]},
Columns: []*schema.Column{UsageLogsColumns[24]},
},
{
Name: "usagelog_model",
@@ -587,12 +545,12 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]},
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]},
Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]},
},
},
}
@@ -843,7 +801,6 @@ var (
ProxiesTable,
RedeemCodesTable,
SettingsTable,
UsageCleanupTasksTable,
UsageLogsTable,
UsersTable,
UserAllowedGroupsTable,
@@ -890,9 +847,6 @@ func init() {
SettingsTable.Annotation = &entsql.Annotation{
Table: "settings",
}
UsageCleanupTasksTable.Annotation = &entsql.Annotation{
Table: "usage_cleanup_tasks",
}
UsageLogsTable.ForeignKeys[0].RefTable = APIKeysTable
UsageLogsTable.ForeignKeys[1].RefTable = AccountsTable
UsageLogsTable.ForeignKeys[2].RefTable = GroupsTable

File diff suppressed because it is too large Load Diff

View File

@@ -33,9 +33,6 @@ type RedeemCode func(*sql.Selector)
// Setting is the predicate function for setting builders.
type Setting func(*sql.Selector)
// UsageCleanupTask is the predicate function for usagecleanuptask builders.
type UsageCleanupTask func(*sql.Selector)
// UsageLog is the predicate function for usagelog builders.
type UsageLog func(*sql.Selector)

View File

@@ -15,7 +15,6 @@ import (
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/schema"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -178,26 +177,22 @@ func init() {
accountDescPriority := accountFields[8].Descriptor()
// account.DefaultPriority holds the default value on creation for the priority field.
account.DefaultPriority = accountDescPriority.Default.(int)
// accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
accountDescRateMultiplier := accountFields[9].Descriptor()
// account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
// accountDescStatus is the schema descriptor for status field.
accountDescStatus := accountFields[10].Descriptor()
accountDescStatus := accountFields[9].Descriptor()
// account.DefaultStatus holds the default value on creation for the status field.
account.DefaultStatus = accountDescStatus.Default.(string)
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
accountDescAutoPauseOnExpired := accountFields[13].Descriptor()
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
// accountDescSchedulable is the schema descriptor for schedulable field.
accountDescSchedulable := accountFields[15].Descriptor()
accountDescSchedulable := accountFields[14].Descriptor()
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
accountDescSessionWindowStatus := accountFields[21].Descriptor()
accountDescSessionWindowStatus := accountFields[20].Descriptor()
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
accountgroupFields := schema.AccountGroup{}.Fields()
@@ -281,10 +276,6 @@ func init() {
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
groupDescModelRoutingEnabled := groupFields[17].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields
// promocodeDescCode is the schema descriptor for code field.
@@ -496,43 +487,6 @@ func init() {
setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time)
// setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time)
usagecleanuptaskMixin := schema.UsageCleanupTask{}.Mixin()
usagecleanuptaskMixinFields0 := usagecleanuptaskMixin[0].Fields()
_ = usagecleanuptaskMixinFields0
usagecleanuptaskFields := schema.UsageCleanupTask{}.Fields()
_ = usagecleanuptaskFields
// usagecleanuptaskDescCreatedAt is the schema descriptor for created_at field.
usagecleanuptaskDescCreatedAt := usagecleanuptaskMixinFields0[0].Descriptor()
// usagecleanuptask.DefaultCreatedAt holds the default value on creation for the created_at field.
usagecleanuptask.DefaultCreatedAt = usagecleanuptaskDescCreatedAt.Default.(func() time.Time)
// usagecleanuptaskDescUpdatedAt is the schema descriptor for updated_at field.
usagecleanuptaskDescUpdatedAt := usagecleanuptaskMixinFields0[1].Descriptor()
// usagecleanuptask.DefaultUpdatedAt holds the default value on creation for the updated_at field.
usagecleanuptask.DefaultUpdatedAt = usagecleanuptaskDescUpdatedAt.Default.(func() time.Time)
// usagecleanuptask.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
usagecleanuptask.UpdateDefaultUpdatedAt = usagecleanuptaskDescUpdatedAt.UpdateDefault.(func() time.Time)
// usagecleanuptaskDescStatus is the schema descriptor for status field.
usagecleanuptaskDescStatus := usagecleanuptaskFields[0].Descriptor()
// usagecleanuptask.StatusValidator is a validator for the "status" field. It is called by the builders before save.
usagecleanuptask.StatusValidator = func() func(string) error {
validators := usagecleanuptaskDescStatus.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(status string) error {
for _, fn := range fns {
if err := fn(status); err != nil {
return err
}
}
return nil
}
}()
// usagecleanuptaskDescDeletedRows is the schema descriptor for deleted_rows field.
usagecleanuptaskDescDeletedRows := usagecleanuptaskFields[3].Descriptor()
// usagecleanuptask.DefaultDeletedRows holds the default value on creation for the deleted_rows field.
usagecleanuptask.DefaultDeletedRows = usagecleanuptaskDescDeletedRows.Default.(int64)
usagelogFields := schema.UsageLog{}.Fields()
_ = usagelogFields
// usagelogDescRequestID is the schema descriptor for request_id field.
@@ -624,31 +578,31 @@ func init() {
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field.
usagelogDescBillingType := usagelogFields[21].Descriptor()
usagelogDescBillingType := usagelogFields[20].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field.
usagelogDescStream := usagelogFields[22].Descriptor()
usagelogDescStream := usagelogFields[21].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field.
usagelogDescUserAgent := usagelogFields[25].Descriptor()
usagelogDescUserAgent := usagelogFields[24].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[26].Descriptor()
usagelogDescIPAddress := usagelogFields[25].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[27].Descriptor()
usagelogDescImageCount := usagelogFields[26].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[28].Descriptor()
usagelogDescImageSize := usagelogFields[27].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[29].Descriptor()
usagelogDescCreatedAt := usagelogFields[28].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()

View File

@@ -102,12 +102,6 @@ func (Account) Fields() []ent.Field {
field.Int("priority").
Default(50),
// rate_multiplier: 账号计费倍率(>=0允许 0 表示该账号计费为 0
// 仅影响账号维度计费口径,不影响用户/API Key 扣费(分组倍率)
field.Float("rate_multiplier").
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
Default(1.0),
// status: 账户状态,如 "active", "error", "disabled"
field.String("status").
MaxLen(20).

View File

@@ -95,17 +95,6 @@ func (Group) Fields() []ent.Field {
Optional().
Nillable().
Comment("非 Claude Code 请求降级使用的分组 ID"),
// 模型路由配置 (added by migration 040)
field.JSON("model_routing", map[string][]int64{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
// 模型路由开关 (added by migration 041)
field.Bool("model_routing_enabled").
Default(false).
Comment("是否启用模型路由配置"),
}
}

View File

@@ -5,7 +5,6 @@ package mixins
import (
"context"
"fmt"
"reflect"
"time"
"entgo.io/ent"
@@ -13,6 +12,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/mixin"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/intercept"
)
@@ -113,6 +113,7 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook {
SetOp(ent.Op)
SetDeletedAt(time.Time)
WhereP(...func(*sql.Selector))
Client() *dbent.Client
})
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
@@ -123,7 +124,7 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook {
mx.SetOp(ent.OpUpdate)
// 设置删除时间为当前时间
mx.SetDeletedAt(time.Now())
return mutateWithClient(ctx, m, next)
return mx.Client().Mutate(ctx, m)
})
},
}
@@ -136,41 +137,3 @@ func (d SoftDeleteMixin) applyPredicate(w interface{ WhereP(...func(*sql.Selecto
sql.FieldIsNull(d.Fields()[0].Descriptor().Name),
)
}
func mutateWithClient(ctx context.Context, m ent.Mutation, fallback ent.Mutator) (ent.Value, error) {
clientMethod := reflect.ValueOf(m).MethodByName("Client")
if !clientMethod.IsValid() || clientMethod.Type().NumIn() != 0 || clientMethod.Type().NumOut() != 1 {
return nil, fmt.Errorf("soft delete: mutation client method not found for %T", m)
}
client := clientMethod.Call(nil)[0]
mutateMethod := client.MethodByName("Mutate")
if !mutateMethod.IsValid() {
return nil, fmt.Errorf("soft delete: mutation client missing Mutate for %T", m)
}
if mutateMethod.Type().NumIn() != 2 || mutateMethod.Type().NumOut() != 2 {
return nil, fmt.Errorf("soft delete: mutation client signature mismatch for %T", m)
}
results := mutateMethod.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(m)})
value := results[0].Interface()
var err error
if !results[1].IsNil() {
errValue := results[1].Interface()
typedErr, ok := errValue.(error)
if !ok {
return nil, fmt.Errorf("soft delete: unexpected error type %T for %T", errValue, m)
}
err = typedErr
}
if err != nil {
return nil, err
}
if value == nil {
return nil, fmt.Errorf("soft delete: mutation client returned nil for %T", m)
}
v, ok := value.(ent.Value)
if !ok {
return nil, fmt.Errorf("soft delete: unexpected value type %T for %T", value, m)
}
return v, nil
}

View File

@@ -1,75 +0,0 @@
package schema
import (
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// UsageCleanupTask 定义使用记录清理任务的 schema。
type UsageCleanupTask struct {
ent.Schema
}
func (UsageCleanupTask) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "usage_cleanup_tasks"},
}
}
func (UsageCleanupTask) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (UsageCleanupTask) Fields() []ent.Field {
return []ent.Field{
field.String("status").
MaxLen(20).
Validate(validateUsageCleanupStatus),
field.JSON("filters", json.RawMessage{}),
field.Int64("created_by"),
field.Int64("deleted_rows").
Default(0),
field.String("error_message").
Optional().
Nillable(),
field.Int64("canceled_by").
Optional().
Nillable(),
field.Time("canceled_at").
Optional().
Nillable(),
field.Time("started_at").
Optional().
Nillable(),
field.Time("finished_at").
Optional().
Nillable(),
}
}
func (UsageCleanupTask) Indexes() []ent.Index {
return []ent.Index{
index.Fields("status", "created_at"),
index.Fields("created_at"),
index.Fields("canceled_at"),
}
}
func validateUsageCleanupStatus(status string) error {
switch status {
case "pending", "running", "succeeded", "failed", "canceled":
return nil
default:
return fmt.Errorf("invalid usage cleanup status: %s", status)
}
}

View File

@@ -85,12 +85,6 @@ func (UsageLog) Fields() []ent.Field {
Default(1).
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
// account_rate_multiplier: 账号计费倍率快照NULL 表示按 1.0 处理)
field.Float("account_rate_multiplier").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
// 其他字段
field.Int8("billing_type").
Default(0),

View File

@@ -32,8 +32,6 @@ type Tx struct {
RedeemCode *RedeemCodeClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
UsageCleanupTask *UsageCleanupTaskClient
// UsageLog is the client for interacting with the UsageLog builders.
UsageLog *UsageLogClient
// User is the client for interacting with the User builders.
@@ -186,7 +184,6 @@ func (tx *Tx) init() {
tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.Setting = NewSettingClient(tx.config)
tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config)
tx.UsageLog = NewUsageLogClient(tx.config)
tx.User = NewUserClient(tx.config)
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)

View File

@@ -1,236 +0,0 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)
// UsageCleanupTask is the model entity for the UsageCleanupTask schema.
type UsageCleanupTask struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// Filters holds the value of the "filters" field.
Filters json.RawMessage `json:"filters,omitempty"`
// CreatedBy holds the value of the "created_by" field.
CreatedBy int64 `json:"created_by,omitempty"`
// DeletedRows holds the value of the "deleted_rows" field.
DeletedRows int64 `json:"deleted_rows,omitempty"`
// ErrorMessage holds the value of the "error_message" field.
ErrorMessage *string `json:"error_message,omitempty"`
// CanceledBy holds the value of the "canceled_by" field.
CanceledBy *int64 `json:"canceled_by,omitempty"`
// CanceledAt holds the value of the "canceled_at" field.
CanceledAt *time.Time `json:"canceled_at,omitempty"`
// StartedAt holds the value of the "started_at" field.
StartedAt *time.Time `json:"started_at,omitempty"`
// FinishedAt holds the value of the "finished_at" field.
FinishedAt *time.Time `json:"finished_at,omitempty"`
selectValues sql.SelectValues
}
// scanValues returns the types for scanning values from sql.Rows.
func (*UsageCleanupTask) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case usagecleanuptask.FieldFilters:
values[i] = new([]byte)
case usagecleanuptask.FieldID, usagecleanuptask.FieldCreatedBy, usagecleanuptask.FieldDeletedRows, usagecleanuptask.FieldCanceledBy:
values[i] = new(sql.NullInt64)
case usagecleanuptask.FieldStatus, usagecleanuptask.FieldErrorMessage:
values[i] = new(sql.NullString)
case usagecleanuptask.FieldCreatedAt, usagecleanuptask.FieldUpdatedAt, usagecleanuptask.FieldCanceledAt, usagecleanuptask.FieldStartedAt, usagecleanuptask.FieldFinishedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the UsageCleanupTask fields.
func (_m *UsageCleanupTask) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case usagecleanuptask.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case usagecleanuptask.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case usagecleanuptask.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case usagecleanuptask.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
} else if value.Valid {
_m.Status = value.String
}
case usagecleanuptask.FieldFilters:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field filters", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.Filters); err != nil {
return fmt.Errorf("unmarshal field filters: %w", err)
}
}
case usagecleanuptask.FieldCreatedBy:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field created_by", values[i])
} else if value.Valid {
_m.CreatedBy = value.Int64
}
case usagecleanuptask.FieldDeletedRows:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field deleted_rows", values[i])
} else if value.Valid {
_m.DeletedRows = value.Int64
}
case usagecleanuptask.FieldErrorMessage:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field error_message", values[i])
} else if value.Valid {
_m.ErrorMessage = new(string)
*_m.ErrorMessage = value.String
}
case usagecleanuptask.FieldCanceledBy:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field canceled_by", values[i])
} else if value.Valid {
_m.CanceledBy = new(int64)
*_m.CanceledBy = value.Int64
}
case usagecleanuptask.FieldCanceledAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field canceled_at", values[i])
} else if value.Valid {
_m.CanceledAt = new(time.Time)
*_m.CanceledAt = value.Time
}
case usagecleanuptask.FieldStartedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field started_at", values[i])
} else if value.Valid {
_m.StartedAt = new(time.Time)
*_m.StartedAt = value.Time
}
case usagecleanuptask.FieldFinishedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field finished_at", values[i])
} else if value.Valid {
_m.FinishedAt = new(time.Time)
*_m.FinishedAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the UsageCleanupTask.
// This includes values selected through modifiers, order, etc.
func (_m *UsageCleanupTask) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// Update returns a builder for updating this UsageCleanupTask.
// Note that you need to call UsageCleanupTask.Unwrap() before calling this method if this UsageCleanupTask
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *UsageCleanupTask) Update() *UsageCleanupTaskUpdateOne {
return NewUsageCleanupTaskClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the UsageCleanupTask entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *UsageCleanupTask) Unwrap() *UsageCleanupTask {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: UsageCleanupTask is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *UsageCleanupTask) String() string {
var builder strings.Builder
builder.WriteString("UsageCleanupTask(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
builder.WriteString("filters=")
builder.WriteString(fmt.Sprintf("%v", _m.Filters))
builder.WriteString(", ")
builder.WriteString("created_by=")
builder.WriteString(fmt.Sprintf("%v", _m.CreatedBy))
builder.WriteString(", ")
builder.WriteString("deleted_rows=")
builder.WriteString(fmt.Sprintf("%v", _m.DeletedRows))
builder.WriteString(", ")
if v := _m.ErrorMessage; v != nil {
builder.WriteString("error_message=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.CanceledBy; v != nil {
builder.WriteString("canceled_by=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.CanceledAt; v != nil {
builder.WriteString("canceled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.StartedAt; v != nil {
builder.WriteString("started_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.FinishedAt; v != nil {
builder.WriteString("finished_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteByte(')')
return builder.String()
}
// UsageCleanupTasks is a parsable slice of UsageCleanupTask.
type UsageCleanupTasks []*UsageCleanupTask

View File

@@ -1,137 +0,0 @@
// Code generated by ent, DO NOT EDIT.
package usagecleanuptask
import (
"time"
"entgo.io/ent/dialect/sql"
)
const (
// Label holds the string label denoting the usagecleanuptask type in the database.
Label = "usage_cleanup_task"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldFilters holds the string denoting the filters field in the database.
FieldFilters = "filters"
// FieldCreatedBy holds the string denoting the created_by field in the database.
FieldCreatedBy = "created_by"
// FieldDeletedRows holds the string denoting the deleted_rows field in the database.
FieldDeletedRows = "deleted_rows"
// FieldErrorMessage holds the string denoting the error_message field in the database.
FieldErrorMessage = "error_message"
// FieldCanceledBy holds the string denoting the canceled_by field in the database.
FieldCanceledBy = "canceled_by"
// FieldCanceledAt holds the string denoting the canceled_at field in the database.
FieldCanceledAt = "canceled_at"
// FieldStartedAt holds the string denoting the started_at field in the database.
FieldStartedAt = "started_at"
// FieldFinishedAt holds the string denoting the finished_at field in the database.
FieldFinishedAt = "finished_at"
// Table holds the table name of the usagecleanuptask in the database.
Table = "usage_cleanup_tasks"
)
// Columns holds all SQL columns for usagecleanuptask fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldStatus,
FieldFilters,
FieldCreatedBy,
FieldDeletedRows,
FieldErrorMessage,
FieldCanceledBy,
FieldCanceledAt,
FieldStartedAt,
FieldFinishedAt,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
StatusValidator func(string) error
// DefaultDeletedRows holds the default value on creation for the "deleted_rows" field.
DefaultDeletedRows int64
)
// OrderOption defines the ordering options for the UsageCleanupTask queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
}
// ByCreatedBy orders the results by the created_by field.
func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedBy, opts...).ToFunc()
}
// ByDeletedRows orders the results by the deleted_rows field.
func ByDeletedRows(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDeletedRows, opts...).ToFunc()
}
// ByErrorMessage orders the results by the error_message field.
func ByErrorMessage(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldErrorMessage, opts...).ToFunc()
}
// ByCanceledBy orders the results by the canceled_by field.
func ByCanceledBy(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCanceledBy, opts...).ToFunc()
}
// ByCanceledAt orders the results by the canceled_at field.
func ByCanceledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCanceledAt, opts...).ToFunc()
}
// ByStartedAt orders the results by the started_at field.
func ByStartedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStartedAt, opts...).ToFunc()
}
// ByFinishedAt orders the results by the finished_at field.
func ByFinishedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFinishedAt, opts...).ToFunc()
}

View File

@@ -1,620 +0,0 @@
// Code generated by ent, DO NOT EDIT.
package usagecleanuptask
import (
"time"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldUpdatedAt, v))
}
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldStatus, v))
}
// CreatedBy applies equality check predicate on the "created_by" field. It's identical to CreatedByEQ.
func CreatedBy(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedBy, v))
}
// DeletedRows applies equality check predicate on the "deleted_rows" field. It's identical to DeletedRowsEQ.
func DeletedRows(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldDeletedRows, v))
}
// ErrorMessage applies equality check predicate on the "error_message" field. It's identical to ErrorMessageEQ.
func ErrorMessage(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldErrorMessage, v))
}
// CanceledBy applies equality check predicate on the "canceled_by" field. It's identical to CanceledByEQ.
func CanceledBy(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledBy, v))
}
// CanceledAt applies equality check predicate on the "canceled_at" field. It's identical to CanceledAtEQ.
func CanceledAt(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledAt, v))
}
// StartedAt applies equality check predicate on the "started_at" field. It's identical to StartedAtEQ.
func StartedAt(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldStartedAt, v))
}
// FinishedAt applies equality check predicate on the "finished_at" field. It's identical to FinishedAtEQ.
func FinishedAt(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldFinishedAt, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldUpdatedAt, v))
}
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldStatus, v))
}
// StatusNEQ applies the NEQ predicate on the "status" field.
func StatusNEQ(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldStatus, v))
}
// StatusIn applies the In predicate on the "status" field.
func StatusIn(vs ...string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldStatus, vs...))
}
// StatusNotIn applies the NotIn predicate on the "status" field.
func StatusNotIn(vs ...string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldStatus, vs...))
}
// StatusGT applies the GT predicate on the "status" field.
func StatusGT(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldStatus, v))
}
// StatusGTE applies the GTE predicate on the "status" field.
func StatusGTE(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldStatus, v))
}
// StatusLT applies the LT predicate on the "status" field.
func StatusLT(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldStatus, v))
}
// StatusLTE applies the LTE predicate on the "status" field.
func StatusLTE(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldStatus, v))
}
// StatusContains applies the Contains predicate on the "status" field.
func StatusContains(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldContains(FieldStatus, v))
}
// StatusHasPrefix applies the HasPrefix predicate on the "status" field.
func StatusHasPrefix(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldHasPrefix(FieldStatus, v))
}
// StatusHasSuffix applies the HasSuffix predicate on the "status" field.
func StatusHasSuffix(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldHasSuffix(FieldStatus, v))
}
// StatusEqualFold applies the EqualFold predicate on the "status" field.
func StatusEqualFold(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEqualFold(FieldStatus, v))
}
// StatusContainsFold applies the ContainsFold predicate on the "status" field.
func StatusContainsFold(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldContainsFold(FieldStatus, v))
}
// CreatedByEQ applies the EQ predicate on the "created_by" field.
func CreatedByEQ(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedBy, v))
}
// CreatedByNEQ applies the NEQ predicate on the "created_by" field.
func CreatedByNEQ(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCreatedBy, v))
}
// CreatedByIn applies the In predicate on the "created_by" field.
func CreatedByIn(vs ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldCreatedBy, vs...))
}
// CreatedByNotIn applies the NotIn predicate on the "created_by" field.
func CreatedByNotIn(vs ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCreatedBy, vs...))
}
// CreatedByGT applies the GT predicate on the "created_by" field.
func CreatedByGT(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldCreatedBy, v))
}
// CreatedByGTE applies the GTE predicate on the "created_by" field.
func CreatedByGTE(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldCreatedBy, v))
}
// CreatedByLT applies the LT predicate on the "created_by" field.
func CreatedByLT(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldCreatedBy, v))
}
// CreatedByLTE applies the LTE predicate on the "created_by" field.
func CreatedByLTE(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldCreatedBy, v))
}
// DeletedRowsEQ applies the EQ predicate on the "deleted_rows" field.
func DeletedRowsEQ(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldDeletedRows, v))
}
// DeletedRowsNEQ applies the NEQ predicate on the "deleted_rows" field.
func DeletedRowsNEQ(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldDeletedRows, v))
}
// DeletedRowsIn applies the In predicate on the "deleted_rows" field.
func DeletedRowsIn(vs ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldDeletedRows, vs...))
}
// DeletedRowsNotIn applies the NotIn predicate on the "deleted_rows" field.
func DeletedRowsNotIn(vs ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldDeletedRows, vs...))
}
// DeletedRowsGT applies the GT predicate on the "deleted_rows" field.
func DeletedRowsGT(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldDeletedRows, v))
}
// DeletedRowsGTE applies the GTE predicate on the "deleted_rows" field.
func DeletedRowsGTE(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldDeletedRows, v))
}
// DeletedRowsLT applies the LT predicate on the "deleted_rows" field.
func DeletedRowsLT(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldDeletedRows, v))
}
// DeletedRowsLTE applies the LTE predicate on the "deleted_rows" field.
func DeletedRowsLTE(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldDeletedRows, v))
}
// ErrorMessageEQ applies the EQ predicate on the "error_message" field.
func ErrorMessageEQ(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldErrorMessage, v))
}
// ErrorMessageNEQ applies the NEQ predicate on the "error_message" field.
func ErrorMessageNEQ(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldErrorMessage, v))
}
// ErrorMessageIn applies the In predicate on the "error_message" field.
func ErrorMessageIn(vs ...string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldErrorMessage, vs...))
}
// ErrorMessageNotIn applies the NotIn predicate on the "error_message" field.
func ErrorMessageNotIn(vs ...string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldErrorMessage, vs...))
}
// ErrorMessageGT applies the GT predicate on the "error_message" field.
func ErrorMessageGT(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldErrorMessage, v))
}
// ErrorMessageGTE applies the GTE predicate on the "error_message" field.
func ErrorMessageGTE(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldErrorMessage, v))
}
// ErrorMessageLT applies the LT predicate on the "error_message" field.
func ErrorMessageLT(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldErrorMessage, v))
}
// ErrorMessageLTE applies the LTE predicate on the "error_message" field.
func ErrorMessageLTE(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldErrorMessage, v))
}
// ErrorMessageContains applies the Contains predicate on the "error_message" field.
func ErrorMessageContains(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldContains(FieldErrorMessage, v))
}
// ErrorMessageHasPrefix applies the HasPrefix predicate on the "error_message" field.
func ErrorMessageHasPrefix(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldHasPrefix(FieldErrorMessage, v))
}
// ErrorMessageHasSuffix applies the HasSuffix predicate on the "error_message" field.
func ErrorMessageHasSuffix(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldHasSuffix(FieldErrorMessage, v))
}
// ErrorMessageIsNil applies the IsNil predicate on the "error_message" field.
func ErrorMessageIsNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldErrorMessage))
}
// ErrorMessageNotNil applies the NotNil predicate on the "error_message" field.
func ErrorMessageNotNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldErrorMessage))
}
// ErrorMessageEqualFold applies the EqualFold predicate on the "error_message" field.
func ErrorMessageEqualFold(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEqualFold(FieldErrorMessage, v))
}
// ErrorMessageContainsFold applies the ContainsFold predicate on the "error_message" field.
func ErrorMessageContainsFold(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldContainsFold(FieldErrorMessage, v))
}
// CanceledByEQ applies the EQ predicate on the "canceled_by" field.
func CanceledByEQ(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledBy, v))
}
// CanceledByNEQ applies the NEQ predicate on the "canceled_by" field.
func CanceledByNEQ(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCanceledBy, v))
}
// CanceledByIn applies the In predicate on the "canceled_by" field.
func CanceledByIn(vs ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldCanceledBy, vs...))
}
// CanceledByNotIn applies the NotIn predicate on the "canceled_by" field.
func CanceledByNotIn(vs ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCanceledBy, vs...))
}
// CanceledByGT applies the GT predicate on the "canceled_by" field.
func CanceledByGT(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldCanceledBy, v))
}
// CanceledByGTE applies the GTE predicate on the "canceled_by" field.
func CanceledByGTE(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldCanceledBy, v))
}
// CanceledByLT applies the LT predicate on the "canceled_by" field.
func CanceledByLT(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldCanceledBy, v))
}
// CanceledByLTE applies the LTE predicate on the "canceled_by" field.
func CanceledByLTE(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldCanceledBy, v))
}
// CanceledByIsNil applies the IsNil predicate on the "canceled_by" field.
func CanceledByIsNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldCanceledBy))
}
// CanceledByNotNil applies the NotNil predicate on the "canceled_by" field.
func CanceledByNotNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldCanceledBy))
}
// CanceledAtEQ applies the EQ predicate on the "canceled_at" field.
func CanceledAtEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledAt, v))
}
// CanceledAtNEQ applies the NEQ predicate on the "canceled_at" field.
func CanceledAtNEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCanceledAt, v))
}
// CanceledAtIn applies the In predicate on the "canceled_at" field.
func CanceledAtIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldCanceledAt, vs...))
}
// CanceledAtNotIn applies the NotIn predicate on the "canceled_at" field.
func CanceledAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCanceledAt, vs...))
}
// CanceledAtGT applies the GT predicate on the "canceled_at" field.
func CanceledAtGT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldCanceledAt, v))
}
// CanceledAtGTE applies the GTE predicate on the "canceled_at" field.
func CanceledAtGTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldCanceledAt, v))
}
// CanceledAtLT applies the LT predicate on the "canceled_at" field.
func CanceledAtLT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldCanceledAt, v))
}
// CanceledAtLTE applies the LTE predicate on the "canceled_at" field.
func CanceledAtLTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldCanceledAt, v))
}
// CanceledAtIsNil applies the IsNil predicate on the "canceled_at" field.
func CanceledAtIsNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldCanceledAt))
}
// CanceledAtNotNil applies the NotNil predicate on the "canceled_at" field.
func CanceledAtNotNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldCanceledAt))
}
// StartedAtEQ applies the EQ predicate on the "started_at" field.
func StartedAtEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldStartedAt, v))
}
// StartedAtNEQ applies the NEQ predicate on the "started_at" field.
func StartedAtNEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldStartedAt, v))
}
// StartedAtIn applies the In predicate on the "started_at" field.
func StartedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldStartedAt, vs...))
}
// StartedAtNotIn applies the NotIn predicate on the "started_at" field.
func StartedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldStartedAt, vs...))
}
// StartedAtGT applies the GT predicate on the "started_at" field.
func StartedAtGT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldStartedAt, v))
}
// StartedAtGTE applies the GTE predicate on the "started_at" field.
func StartedAtGTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldStartedAt, v))
}
// StartedAtLT applies the LT predicate on the "started_at" field.
func StartedAtLT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldStartedAt, v))
}
// StartedAtLTE applies the LTE predicate on the "started_at" field.
func StartedAtLTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldStartedAt, v))
}
// StartedAtIsNil applies the IsNil predicate on the "started_at" field.
func StartedAtIsNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldStartedAt))
}
// StartedAtNotNil applies the NotNil predicate on the "started_at" field.
func StartedAtNotNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldStartedAt))
}
// FinishedAtEQ applies the EQ predicate on the "finished_at" field.
func FinishedAtEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldFinishedAt, v))
}
// FinishedAtNEQ applies the NEQ predicate on the "finished_at" field.
func FinishedAtNEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldFinishedAt, v))
}
// FinishedAtIn applies the In predicate on the "finished_at" field.
func FinishedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldFinishedAt, vs...))
}
// FinishedAtNotIn applies the NotIn predicate on the "finished_at" field.
func FinishedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldFinishedAt, vs...))
}
// FinishedAtGT applies the GT predicate on the "finished_at" field.
func FinishedAtGT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldFinishedAt, v))
}
// FinishedAtGTE applies the GTE predicate on the "finished_at" field.
func FinishedAtGTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldFinishedAt, v))
}
// FinishedAtLT applies the LT predicate on the "finished_at" field.
func FinishedAtLT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldFinishedAt, v))
}
// FinishedAtLTE applies the LTE predicate on the "finished_at" field.
func FinishedAtLTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldFinishedAt, v))
}
// FinishedAtIsNil applies the IsNil predicate on the "finished_at" field.
func FinishedAtIsNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldFinishedAt))
}
// FinishedAtNotNil applies the NotNil predicate on the "finished_at" field.
func FinishedAtNotNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldFinishedAt))
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.UsageCleanupTask) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.UsageCleanupTask) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.UsageCleanupTask) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.NotPredicates(p))
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,88 +0,0 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)
// UsageCleanupTaskDelete is the builder for deleting a UsageCleanupTask entity.
type UsageCleanupTaskDelete struct {
config
hooks []Hook
mutation *UsageCleanupTaskMutation
}
// Where appends a list predicates to the UsageCleanupTaskDelete builder.
func (_d *UsageCleanupTaskDelete) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *UsageCleanupTaskDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UsageCleanupTaskDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *UsageCleanupTaskDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(usagecleanuptask.Table, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// UsageCleanupTaskDeleteOne is the builder for deleting a single UsageCleanupTask entity.
type UsageCleanupTaskDeleteOne struct {
_d *UsageCleanupTaskDelete
}
// Where appends a list predicates to the UsageCleanupTaskDelete builder.
func (_d *UsageCleanupTaskDeleteOne) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *UsageCleanupTaskDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{usagecleanuptask.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UsageCleanupTaskDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -1,564 +0,0 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)
// UsageCleanupTaskQuery is the builder for querying UsageCleanupTask entities.
type UsageCleanupTaskQuery struct {
config
ctx *QueryContext
order []usagecleanuptask.OrderOption
inters []Interceptor
predicates []predicate.UsageCleanupTask
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the UsageCleanupTaskQuery builder.
func (_q *UsageCleanupTaskQuery) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *UsageCleanupTaskQuery) Limit(limit int) *UsageCleanupTaskQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *UsageCleanupTaskQuery) Offset(offset int) *UsageCleanupTaskQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *UsageCleanupTaskQuery) Unique(unique bool) *UsageCleanupTaskQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *UsageCleanupTaskQuery) Order(o ...usagecleanuptask.OrderOption) *UsageCleanupTaskQuery {
_q.order = append(_q.order, o...)
return _q
}
// First returns the first UsageCleanupTask entity from the query.
// Returns a *NotFoundError when no UsageCleanupTask was found.
func (_q *UsageCleanupTaskQuery) First(ctx context.Context) (*UsageCleanupTask, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{usagecleanuptask.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) FirstX(ctx context.Context) *UsageCleanupTask {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first UsageCleanupTask ID from the query.
// Returns a *NotFoundError when no UsageCleanupTask ID was found.
func (_q *UsageCleanupTaskQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{usagecleanuptask.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single UsageCleanupTask entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one UsageCleanupTask entity is found.
// Returns a *NotFoundError when no UsageCleanupTask entities are found.
func (_q *UsageCleanupTaskQuery) Only(ctx context.Context) (*UsageCleanupTask, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{usagecleanuptask.Label}
default:
return nil, &NotSingularError{usagecleanuptask.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) OnlyX(ctx context.Context) *UsageCleanupTask {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only UsageCleanupTask ID in the query.
// Returns a *NotSingularError when more than one UsageCleanupTask ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *UsageCleanupTaskQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{usagecleanuptask.Label}
default:
err = &NotSingularError{usagecleanuptask.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of UsageCleanupTasks.
func (_q *UsageCleanupTaskQuery) All(ctx context.Context) ([]*UsageCleanupTask, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*UsageCleanupTask, *UsageCleanupTaskQuery]()
return withInterceptors[[]*UsageCleanupTask](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) AllX(ctx context.Context) []*UsageCleanupTask {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of UsageCleanupTask IDs.
func (_q *UsageCleanupTaskQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(usagecleanuptask.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *UsageCleanupTaskQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*UsageCleanupTaskQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *UsageCleanupTaskQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the UsageCleanupTaskQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *UsageCleanupTaskQuery) Clone() *UsageCleanupTaskQuery {
if _q == nil {
return nil
}
return &UsageCleanupTaskQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]usagecleanuptask.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.UsageCleanupTask{}, _q.predicates...),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.UsageCleanupTask.Query().
// GroupBy(usagecleanuptask.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *UsageCleanupTaskQuery) GroupBy(field string, fields ...string) *UsageCleanupTaskGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &UsageCleanupTaskGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = usagecleanuptask.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.UsageCleanupTask.Query().
// Select(usagecleanuptask.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *UsageCleanupTaskQuery) Select(fields ...string) *UsageCleanupTaskSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &UsageCleanupTaskSelect{UsageCleanupTaskQuery: _q}
sbuild.label = usagecleanuptask.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a UsageCleanupTaskSelect configured with the given aggregations.
func (_q *UsageCleanupTaskQuery) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *UsageCleanupTaskQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !usagecleanuptask.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *UsageCleanupTaskQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UsageCleanupTask, error) {
var (
nodes = []*UsageCleanupTask{}
_spec = _q.querySpec()
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*UsageCleanupTask).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &UsageCleanupTask{config: _q.config}
nodes = append(nodes, node)
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
return nodes, nil
}
func (_q *UsageCleanupTaskQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *UsageCleanupTaskQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, usagecleanuptask.FieldID)
for i := range fields {
if fields[i] != usagecleanuptask.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *UsageCleanupTaskQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(usagecleanuptask.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = usagecleanuptask.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UsageCleanupTaskQuery) ForUpdate(opts ...sql.LockOption) *UsageCleanupTaskQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UsageCleanupTaskQuery) ForShare(opts ...sql.LockOption) *UsageCleanupTaskQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UsageCleanupTaskGroupBy is the group-by builder for UsageCleanupTask entities.
type UsageCleanupTaskGroupBy struct {
selector
build *UsageCleanupTaskQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *UsageCleanupTaskGroupBy) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *UsageCleanupTaskGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UsageCleanupTaskQuery, *UsageCleanupTaskGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *UsageCleanupTaskGroupBy) sqlScan(ctx context.Context, root *UsageCleanupTaskQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// UsageCleanupTaskSelect is the builder for selecting fields of UsageCleanupTask entities.
type UsageCleanupTaskSelect struct {
*UsageCleanupTaskQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *UsageCleanupTaskSelect) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *UsageCleanupTaskSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UsageCleanupTaskQuery, *UsageCleanupTaskSelect](ctx, _s.UsageCleanupTaskQuery, _s, _s.inters, v)
}
func (_s *UsageCleanupTaskSelect) sqlScan(ctx context.Context, root *UsageCleanupTaskQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -1,702 +0,0 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)
// UsageCleanupTaskUpdate is the builder for updating UsageCleanupTask entities.
type UsageCleanupTaskUpdate struct {
config
hooks []Hook
mutation *UsageCleanupTaskMutation
}
// Where appends a list predicates to the UsageCleanupTaskUpdate builder.
func (_u *UsageCleanupTaskUpdate) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UsageCleanupTaskUpdate) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetStatus sets the "status" field.
func (_u *UsageCleanupTaskUpdate) SetStatus(v string) *UsageCleanupTaskUpdate {
_u.mutation.SetStatus(v)
return _u
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableStatus(v *string) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetStatus(*v)
}
return _u
}
// SetFilters sets the "filters" field.
func (_u *UsageCleanupTaskUpdate) SetFilters(v json.RawMessage) *UsageCleanupTaskUpdate {
_u.mutation.SetFilters(v)
return _u
}
// AppendFilters appends value to the "filters" field.
func (_u *UsageCleanupTaskUpdate) AppendFilters(v json.RawMessage) *UsageCleanupTaskUpdate {
_u.mutation.AppendFilters(v)
return _u
}
// SetCreatedBy sets the "created_by" field.
func (_u *UsageCleanupTaskUpdate) SetCreatedBy(v int64) *UsageCleanupTaskUpdate {
_u.mutation.ResetCreatedBy()
_u.mutation.SetCreatedBy(v)
return _u
}
// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableCreatedBy(v *int64) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetCreatedBy(*v)
}
return _u
}
// AddCreatedBy adds value to the "created_by" field.
func (_u *UsageCleanupTaskUpdate) AddCreatedBy(v int64) *UsageCleanupTaskUpdate {
_u.mutation.AddCreatedBy(v)
return _u
}
// SetDeletedRows sets the "deleted_rows" field.
func (_u *UsageCleanupTaskUpdate) SetDeletedRows(v int64) *UsageCleanupTaskUpdate {
_u.mutation.ResetDeletedRows()
_u.mutation.SetDeletedRows(v)
return _u
}
// SetNillableDeletedRows sets the "deleted_rows" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableDeletedRows(v *int64) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetDeletedRows(*v)
}
return _u
}
// AddDeletedRows adds value to the "deleted_rows" field.
func (_u *UsageCleanupTaskUpdate) AddDeletedRows(v int64) *UsageCleanupTaskUpdate {
_u.mutation.AddDeletedRows(v)
return _u
}
// SetErrorMessage sets the "error_message" field.
func (_u *UsageCleanupTaskUpdate) SetErrorMessage(v string) *UsageCleanupTaskUpdate {
_u.mutation.SetErrorMessage(v)
return _u
}
// SetNillableErrorMessage sets the "error_message" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableErrorMessage(v *string) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetErrorMessage(*v)
}
return _u
}
// ClearErrorMessage clears the value of the "error_message" field.
func (_u *UsageCleanupTaskUpdate) ClearErrorMessage() *UsageCleanupTaskUpdate {
_u.mutation.ClearErrorMessage()
return _u
}
// SetCanceledBy sets the "canceled_by" field.
func (_u *UsageCleanupTaskUpdate) SetCanceledBy(v int64) *UsageCleanupTaskUpdate {
_u.mutation.ResetCanceledBy()
_u.mutation.SetCanceledBy(v)
return _u
}
// SetNillableCanceledBy sets the "canceled_by" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableCanceledBy(v *int64) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetCanceledBy(*v)
}
return _u
}
// AddCanceledBy adds value to the "canceled_by" field.
func (_u *UsageCleanupTaskUpdate) AddCanceledBy(v int64) *UsageCleanupTaskUpdate {
_u.mutation.AddCanceledBy(v)
return _u
}
// ClearCanceledBy clears the value of the "canceled_by" field.
func (_u *UsageCleanupTaskUpdate) ClearCanceledBy() *UsageCleanupTaskUpdate {
_u.mutation.ClearCanceledBy()
return _u
}
// SetCanceledAt sets the "canceled_at" field.
func (_u *UsageCleanupTaskUpdate) SetCanceledAt(v time.Time) *UsageCleanupTaskUpdate {
_u.mutation.SetCanceledAt(v)
return _u
}
// SetNillableCanceledAt sets the "canceled_at" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableCanceledAt(v *time.Time) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetCanceledAt(*v)
}
return _u
}
// ClearCanceledAt clears the value of the "canceled_at" field.
func (_u *UsageCleanupTaskUpdate) ClearCanceledAt() *UsageCleanupTaskUpdate {
_u.mutation.ClearCanceledAt()
return _u
}
// SetStartedAt sets the "started_at" field.
func (_u *UsageCleanupTaskUpdate) SetStartedAt(v time.Time) *UsageCleanupTaskUpdate {
_u.mutation.SetStartedAt(v)
return _u
}
// SetNillableStartedAt sets the "started_at" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableStartedAt(v *time.Time) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetStartedAt(*v)
}
return _u
}
// ClearStartedAt clears the value of the "started_at" field.
func (_u *UsageCleanupTaskUpdate) ClearStartedAt() *UsageCleanupTaskUpdate {
_u.mutation.ClearStartedAt()
return _u
}
// SetFinishedAt sets the "finished_at" field.
func (_u *UsageCleanupTaskUpdate) SetFinishedAt(v time.Time) *UsageCleanupTaskUpdate {
_u.mutation.SetFinishedAt(v)
return _u
}
// SetNillableFinishedAt sets the "finished_at" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableFinishedAt(v *time.Time) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetFinishedAt(*v)
}
return _u
}
// ClearFinishedAt clears the value of the "finished_at" field.
func (_u *UsageCleanupTaskUpdate) ClearFinishedAt() *UsageCleanupTaskUpdate {
_u.mutation.ClearFinishedAt()
return _u
}
// Mutation returns the UsageCleanupTaskMutation object of the builder.
func (_u *UsageCleanupTaskUpdate) Mutation() *UsageCleanupTaskMutation {
return _u.mutation
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UsageCleanupTaskUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UsageCleanupTaskUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *UsageCleanupTaskUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UsageCleanupTaskUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UsageCleanupTaskUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := usagecleanuptask.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *UsageCleanupTaskUpdate) check() error {
if v, ok := _u.mutation.Status(); ok {
if err := usagecleanuptask.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UsageCleanupTask.status": %w`, err)}
}
}
return nil
}
func (_u *UsageCleanupTaskUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usagecleanuptask.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(usagecleanuptask.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.Filters(); ok {
_spec.SetField(usagecleanuptask.FieldFilters, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedFilters(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, usagecleanuptask.FieldFilters, value)
})
}
if value, ok := _u.mutation.CreatedBy(); ok {
_spec.SetField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedCreatedBy(); ok {
_spec.AddField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.DeletedRows(); ok {
_spec.SetField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedDeletedRows(); ok {
_spec.AddField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
}
if value, ok := _u.mutation.ErrorMessage(); ok {
_spec.SetField(usagecleanuptask.FieldErrorMessage, field.TypeString, value)
}
if _u.mutation.ErrorMessageCleared() {
_spec.ClearField(usagecleanuptask.FieldErrorMessage, field.TypeString)
}
if value, ok := _u.mutation.CanceledBy(); ok {
_spec.SetField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedCanceledBy(); ok {
_spec.AddField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
}
if _u.mutation.CanceledByCleared() {
_spec.ClearField(usagecleanuptask.FieldCanceledBy, field.TypeInt64)
}
if value, ok := _u.mutation.CanceledAt(); ok {
_spec.SetField(usagecleanuptask.FieldCanceledAt, field.TypeTime, value)
}
if _u.mutation.CanceledAtCleared() {
_spec.ClearField(usagecleanuptask.FieldCanceledAt, field.TypeTime)
}
if value, ok := _u.mutation.StartedAt(); ok {
_spec.SetField(usagecleanuptask.FieldStartedAt, field.TypeTime, value)
}
if _u.mutation.StartedAtCleared() {
_spec.ClearField(usagecleanuptask.FieldStartedAt, field.TypeTime)
}
if value, ok := _u.mutation.FinishedAt(); ok {
_spec.SetField(usagecleanuptask.FieldFinishedAt, field.TypeTime, value)
}
if _u.mutation.FinishedAtCleared() {
_spec.ClearField(usagecleanuptask.FieldFinishedAt, field.TypeTime)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{usagecleanuptask.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// UsageCleanupTaskUpdateOne is the builder for updating a single UsageCleanupTask entity.
type UsageCleanupTaskUpdateOne struct {
config
fields []string
hooks []Hook
mutation *UsageCleanupTaskMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UsageCleanupTaskUpdateOne) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetStatus sets the "status" field.
func (_u *UsageCleanupTaskUpdateOne) SetStatus(v string) *UsageCleanupTaskUpdateOne {
_u.mutation.SetStatus(v)
return _u
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableStatus(v *string) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetStatus(*v)
}
return _u
}
// SetFilters sets the "filters" field.
func (_u *UsageCleanupTaskUpdateOne) SetFilters(v json.RawMessage) *UsageCleanupTaskUpdateOne {
_u.mutation.SetFilters(v)
return _u
}
// AppendFilters appends value to the "filters" field.
func (_u *UsageCleanupTaskUpdateOne) AppendFilters(v json.RawMessage) *UsageCleanupTaskUpdateOne {
_u.mutation.AppendFilters(v)
return _u
}
// SetCreatedBy sets the "created_by" field.
func (_u *UsageCleanupTaskUpdateOne) SetCreatedBy(v int64) *UsageCleanupTaskUpdateOne {
_u.mutation.ResetCreatedBy()
_u.mutation.SetCreatedBy(v)
return _u
}
// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableCreatedBy(v *int64) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetCreatedBy(*v)
}
return _u
}
// AddCreatedBy adds value to the "created_by" field.
func (_u *UsageCleanupTaskUpdateOne) AddCreatedBy(v int64) *UsageCleanupTaskUpdateOne {
_u.mutation.AddCreatedBy(v)
return _u
}
// SetDeletedRows sets the "deleted_rows" field.
func (_u *UsageCleanupTaskUpdateOne) SetDeletedRows(v int64) *UsageCleanupTaskUpdateOne {
_u.mutation.ResetDeletedRows()
_u.mutation.SetDeletedRows(v)
return _u
}
// SetNillableDeletedRows sets the "deleted_rows" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableDeletedRows(v *int64) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetDeletedRows(*v)
}
return _u
}
// AddDeletedRows adds value to the "deleted_rows" field.
func (_u *UsageCleanupTaskUpdateOne) AddDeletedRows(v int64) *UsageCleanupTaskUpdateOne {
_u.mutation.AddDeletedRows(v)
return _u
}
// SetErrorMessage sets the "error_message" field.
func (_u *UsageCleanupTaskUpdateOne) SetErrorMessage(v string) *UsageCleanupTaskUpdateOne {
_u.mutation.SetErrorMessage(v)
return _u
}
// SetNillableErrorMessage sets the "error_message" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableErrorMessage(v *string) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetErrorMessage(*v)
}
return _u
}
// ClearErrorMessage clears the value of the "error_message" field.
func (_u *UsageCleanupTaskUpdateOne) ClearErrorMessage() *UsageCleanupTaskUpdateOne {
_u.mutation.ClearErrorMessage()
return _u
}
// SetCanceledBy sets the "canceled_by" field.
func (_u *UsageCleanupTaskUpdateOne) SetCanceledBy(v int64) *UsageCleanupTaskUpdateOne {
_u.mutation.ResetCanceledBy()
_u.mutation.SetCanceledBy(v)
return _u
}
// SetNillableCanceledBy sets the "canceled_by" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableCanceledBy(v *int64) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetCanceledBy(*v)
}
return _u
}
// AddCanceledBy adds value to the "canceled_by" field.
func (_u *UsageCleanupTaskUpdateOne) AddCanceledBy(v int64) *UsageCleanupTaskUpdateOne {
_u.mutation.AddCanceledBy(v)
return _u
}
// ClearCanceledBy clears the value of the "canceled_by" field.
func (_u *UsageCleanupTaskUpdateOne) ClearCanceledBy() *UsageCleanupTaskUpdateOne {
_u.mutation.ClearCanceledBy()
return _u
}
// SetCanceledAt sets the "canceled_at" field.
func (_u *UsageCleanupTaskUpdateOne) SetCanceledAt(v time.Time) *UsageCleanupTaskUpdateOne {
_u.mutation.SetCanceledAt(v)
return _u
}
// SetNillableCanceledAt sets the "canceled_at" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableCanceledAt(v *time.Time) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetCanceledAt(*v)
}
return _u
}
// ClearCanceledAt clears the value of the "canceled_at" field.
func (_u *UsageCleanupTaskUpdateOne) ClearCanceledAt() *UsageCleanupTaskUpdateOne {
_u.mutation.ClearCanceledAt()
return _u
}
// SetStartedAt sets the "started_at" field.
func (_u *UsageCleanupTaskUpdateOne) SetStartedAt(v time.Time) *UsageCleanupTaskUpdateOne {
_u.mutation.SetStartedAt(v)
return _u
}
// SetNillableStartedAt sets the "started_at" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableStartedAt(v *time.Time) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetStartedAt(*v)
}
return _u
}
// ClearStartedAt clears the value of the "started_at" field.
func (_u *UsageCleanupTaskUpdateOne) ClearStartedAt() *UsageCleanupTaskUpdateOne {
_u.mutation.ClearStartedAt()
return _u
}
// SetFinishedAt sets the "finished_at" field.
func (_u *UsageCleanupTaskUpdateOne) SetFinishedAt(v time.Time) *UsageCleanupTaskUpdateOne {
_u.mutation.SetFinishedAt(v)
return _u
}
// SetNillableFinishedAt sets the "finished_at" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableFinishedAt(v *time.Time) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetFinishedAt(*v)
}
return _u
}
// ClearFinishedAt clears the value of the "finished_at" field.
func (_u *UsageCleanupTaskUpdateOne) ClearFinishedAt() *UsageCleanupTaskUpdateOne {
_u.mutation.ClearFinishedAt()
return _u
}
// Mutation returns the UsageCleanupTaskMutation object of the builder.
func (_u *UsageCleanupTaskUpdateOne) Mutation() *UsageCleanupTaskMutation {
return _u.mutation
}
// Where appends a list predicates to the UsageCleanupTaskUpdate builder.
func (_u *UsageCleanupTaskUpdateOne) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *UsageCleanupTaskUpdateOne) Select(field string, fields ...string) *UsageCleanupTaskUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated UsageCleanupTask entity.
func (_u *UsageCleanupTaskUpdateOne) Save(ctx context.Context) (*UsageCleanupTask, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UsageCleanupTaskUpdateOne) SaveX(ctx context.Context) *UsageCleanupTask {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *UsageCleanupTaskUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UsageCleanupTaskUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UsageCleanupTaskUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := usagecleanuptask.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *UsageCleanupTaskUpdateOne) check() error {
if v, ok := _u.mutation.Status(); ok {
if err := usagecleanuptask.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UsageCleanupTask.status": %w`, err)}
}
}
return nil
}
func (_u *UsageCleanupTaskUpdateOne) sqlSave(ctx context.Context) (_node *UsageCleanupTask, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UsageCleanupTask.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, usagecleanuptask.FieldID)
for _, f := range fields {
if !usagecleanuptask.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != usagecleanuptask.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usagecleanuptask.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(usagecleanuptask.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.Filters(); ok {
_spec.SetField(usagecleanuptask.FieldFilters, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedFilters(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, usagecleanuptask.FieldFilters, value)
})
}
if value, ok := _u.mutation.CreatedBy(); ok {
_spec.SetField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedCreatedBy(); ok {
_spec.AddField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.DeletedRows(); ok {
_spec.SetField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedDeletedRows(); ok {
_spec.AddField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
}
if value, ok := _u.mutation.ErrorMessage(); ok {
_spec.SetField(usagecleanuptask.FieldErrorMessage, field.TypeString, value)
}
if _u.mutation.ErrorMessageCleared() {
_spec.ClearField(usagecleanuptask.FieldErrorMessage, field.TypeString)
}
if value, ok := _u.mutation.CanceledBy(); ok {
_spec.SetField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedCanceledBy(); ok {
_spec.AddField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
}
if _u.mutation.CanceledByCleared() {
_spec.ClearField(usagecleanuptask.FieldCanceledBy, field.TypeInt64)
}
if value, ok := _u.mutation.CanceledAt(); ok {
_spec.SetField(usagecleanuptask.FieldCanceledAt, field.TypeTime, value)
}
if _u.mutation.CanceledAtCleared() {
_spec.ClearField(usagecleanuptask.FieldCanceledAt, field.TypeTime)
}
if value, ok := _u.mutation.StartedAt(); ok {
_spec.SetField(usagecleanuptask.FieldStartedAt, field.TypeTime, value)
}
if _u.mutation.StartedAtCleared() {
_spec.ClearField(usagecleanuptask.FieldStartedAt, field.TypeTime)
}
if value, ok := _u.mutation.FinishedAt(); ok {
_spec.SetField(usagecleanuptask.FieldFinishedAt, field.TypeTime, value)
}
if _u.mutation.FinishedAtCleared() {
_spec.ClearField(usagecleanuptask.FieldFinishedAt, field.TypeTime)
}
_node = &UsageCleanupTask{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{usagecleanuptask.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -62,8 +62,6 @@ type UsageLog struct {
ActualCost float64 `json:"actual_cost,omitempty"`
// RateMultiplier holds the value of the "rate_multiplier" field.
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
// AccountRateMultiplier holds the value of the "account_rate_multiplier" field.
AccountRateMultiplier *float64 `json:"account_rate_multiplier,omitempty"`
// BillingType holds the value of the "billing_type" field.
BillingType int8 `json:"billing_type,omitempty"`
// Stream holds the value of the "stream" field.
@@ -167,7 +165,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case usagelog.FieldStream:
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier:
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
@@ -318,13 +316,6 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.RateMultiplier = value.Float64
}
case usagelog.FieldAccountRateMultiplier:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field account_rate_multiplier", values[i])
} else if value.Valid {
_m.AccountRateMultiplier = new(float64)
*_m.AccountRateMultiplier = value.Float64
}
case usagelog.FieldBillingType:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field billing_type", values[i])
@@ -509,11 +500,6 @@ func (_m *UsageLog) String() string {
builder.WriteString("rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
builder.WriteString(", ")
if v := _m.AccountRateMultiplier; v != nil {
builder.WriteString("account_rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("billing_type=")
builder.WriteString(fmt.Sprintf("%v", _m.BillingType))
builder.WriteString(", ")

View File

@@ -54,8 +54,6 @@ const (
FieldActualCost = "actual_cost"
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
FieldRateMultiplier = "rate_multiplier"
// FieldAccountRateMultiplier holds the string denoting the account_rate_multiplier field in the database.
FieldAccountRateMultiplier = "account_rate_multiplier"
// FieldBillingType holds the string denoting the billing_type field in the database.
FieldBillingType = "billing_type"
// FieldStream holds the string denoting the stream field in the database.
@@ -146,7 +144,6 @@ var Columns = []string{
FieldTotalCost,
FieldActualCost,
FieldRateMultiplier,
FieldAccountRateMultiplier,
FieldBillingType,
FieldStream,
FieldDurationMs,
@@ -323,11 +320,6 @@ func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
}
// ByAccountRateMultiplier orders the results by the account_rate_multiplier field.
func ByAccountRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAccountRateMultiplier, opts...).ToFunc()
}
// ByBillingType orders the results by the billing_type field.
func ByBillingType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBillingType, opts...).ToFunc()

View File

@@ -155,11 +155,6 @@ func RateMultiplier(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v))
}
// AccountRateMultiplier applies equality check predicate on the "account_rate_multiplier" field. It's identical to AccountRateMultiplierEQ.
func AccountRateMultiplier(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
}
// BillingType applies equality check predicate on the "billing_type" field. It's identical to BillingTypeEQ.
func BillingType(v int8) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
@@ -975,56 +970,6 @@ func RateMultiplierLTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldRateMultiplier, v))
}
// AccountRateMultiplierEQ applies the EQ predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierEQ(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierNEQ applies the NEQ predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierNEQ(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierIn applies the In predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierIn(vs ...float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldAccountRateMultiplier, vs...))
}
// AccountRateMultiplierNotIn applies the NotIn predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierNotIn(vs ...float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldAccountRateMultiplier, vs...))
}
// AccountRateMultiplierGT applies the GT predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierGT(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierGTE applies the GTE predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierGTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierLT applies the LT predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierLT(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierLTE applies the LTE predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierLTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierIsNil applies the IsNil predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldAccountRateMultiplier))
}
// AccountRateMultiplierNotNil applies the NotNil predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldAccountRateMultiplier))
}
// BillingTypeEQ applies the EQ predicate on the "billing_type" field.
func BillingTypeEQ(v int8) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))

View File

@@ -267,20 +267,6 @@ func (_c *UsageLogCreate) SetNillableRateMultiplier(v *float64) *UsageLogCreate
return _c
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (_c *UsageLogCreate) SetAccountRateMultiplier(v float64) *UsageLogCreate {
_c.mutation.SetAccountRateMultiplier(v)
return _c
}
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableAccountRateMultiplier(v *float64) *UsageLogCreate {
if v != nil {
_c.SetAccountRateMultiplier(*v)
}
return _c
}
// SetBillingType sets the "billing_type" field.
func (_c *UsageLogCreate) SetBillingType(v int8) *UsageLogCreate {
_c.mutation.SetBillingType(v)
@@ -726,10 +712,6 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
_node.RateMultiplier = value
}
if value, ok := _c.mutation.AccountRateMultiplier(); ok {
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
_node.AccountRateMultiplier = &value
}
if value, ok := _c.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
_node.BillingType = value
@@ -1233,30 +1215,6 @@ func (u *UsageLogUpsert) AddRateMultiplier(v float64) *UsageLogUpsert {
return u
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (u *UsageLogUpsert) SetAccountRateMultiplier(v float64) *UsageLogUpsert {
u.Set(usagelog.FieldAccountRateMultiplier, v)
return u
}
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateAccountRateMultiplier() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldAccountRateMultiplier)
return u
}
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
func (u *UsageLogUpsert) AddAccountRateMultiplier(v float64) *UsageLogUpsert {
u.Add(usagelog.FieldAccountRateMultiplier, v)
return u
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (u *UsageLogUpsert) ClearAccountRateMultiplier() *UsageLogUpsert {
u.SetNull(usagelog.FieldAccountRateMultiplier)
return u
}
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsert) SetBillingType(v int8) *UsageLogUpsert {
u.Set(usagelog.FieldBillingType, v)
@@ -1837,34 +1795,6 @@ func (u *UsageLogUpsertOne) UpdateRateMultiplier() *UsageLogUpsertOne {
})
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (u *UsageLogUpsertOne) SetAccountRateMultiplier(v float64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetAccountRateMultiplier(v)
})
}
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
func (u *UsageLogUpsertOne) AddAccountRateMultiplier(v float64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.AddAccountRateMultiplier(v)
})
}
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateAccountRateMultiplier() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateAccountRateMultiplier()
})
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (u *UsageLogUpsertOne) ClearAccountRateMultiplier() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearAccountRateMultiplier()
})
}
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsertOne) SetBillingType(v int8) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2636,34 +2566,6 @@ func (u *UsageLogUpsertBulk) UpdateRateMultiplier() *UsageLogUpsertBulk {
})
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (u *UsageLogUpsertBulk) SetAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetAccountRateMultiplier(v)
})
}
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
func (u *UsageLogUpsertBulk) AddAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.AddAccountRateMultiplier(v)
})
}
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateAccountRateMultiplier() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateAccountRateMultiplier()
})
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (u *UsageLogUpsertBulk) ClearAccountRateMultiplier() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearAccountRateMultiplier()
})
}
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsertBulk) SetBillingType(v int8) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {

View File

@@ -415,33 +415,6 @@ func (_u *UsageLogUpdate) AddRateMultiplier(v float64) *UsageLogUpdate {
return _u
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (_u *UsageLogUpdate) SetAccountRateMultiplier(v float64) *UsageLogUpdate {
_u.mutation.ResetAccountRateMultiplier()
_u.mutation.SetAccountRateMultiplier(v)
return _u
}
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdate {
if v != nil {
_u.SetAccountRateMultiplier(*v)
}
return _u
}
// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
func (_u *UsageLogUpdate) AddAccountRateMultiplier(v float64) *UsageLogUpdate {
_u.mutation.AddAccountRateMultiplier(v)
return _u
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (_u *UsageLogUpdate) ClearAccountRateMultiplier() *UsageLogUpdate {
_u.mutation.ClearAccountRateMultiplier()
return _u
}
// SetBillingType sets the "billing_type" field.
func (_u *UsageLogUpdate) SetBillingType(v int8) *UsageLogUpdate {
_u.mutation.ResetBillingType()
@@ -834,15 +807,6 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AccountRateMultiplier(); ok {
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if _u.mutation.AccountRateMultiplierCleared() {
_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
}
if value, ok := _u.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
}
@@ -1442,33 +1406,6 @@ func (_u *UsageLogUpdateOne) AddRateMultiplier(v float64) *UsageLogUpdateOne {
return _u
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (_u *UsageLogUpdateOne) SetAccountRateMultiplier(v float64) *UsageLogUpdateOne {
_u.mutation.ResetAccountRateMultiplier()
_u.mutation.SetAccountRateMultiplier(v)
return _u
}
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdateOne {
if v != nil {
_u.SetAccountRateMultiplier(*v)
}
return _u
}
// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
func (_u *UsageLogUpdateOne) AddAccountRateMultiplier(v float64) *UsageLogUpdateOne {
_u.mutation.AddAccountRateMultiplier(v)
return _u
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (_u *UsageLogUpdateOne) ClearAccountRateMultiplier() *UsageLogUpdateOne {
_u.mutation.ClearAccountRateMultiplier()
return _u
}
// SetBillingType sets the "billing_type" field.
func (_u *UsageLogUpdateOne) SetBillingType(v int8) *UsageLogUpdateOne {
_u.mutation.ResetBillingType()
@@ -1891,15 +1828,6 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AccountRateMultiplier(); ok {
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if _u.mutation.AccountRateMultiplierCleared() {
_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
}
if value, ok := _u.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
}

View File

@@ -8,11 +8,9 @@ require (
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/google/uuid v1.6.0
github.com/google/wire v0.7.0
github.com/gorilla/websocket v1.5.3
github.com/imroc/req/v3 v3.57.0
github.com/lib/pq v1.10.9
github.com/redis/go-redis/v9 v9.17.2
github.com/shirou/gopsutil/v4 v4.25.6
github.com/spf13/viper v1.18.2
github.com/stretchr/testify v1.11.1
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
@@ -31,7 +29,6 @@ require (
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect
dario.cat/mergo v1.0.2 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
@@ -47,13 +44,11 @@ require (
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgraph-io/ristretto v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
@@ -98,7 +93,6 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
@@ -109,11 +103,10 @@ require (
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.57.1 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
@@ -142,7 +135,7 @@ require (
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/mod v0.30.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect
@@ -151,8 +144,4 @@ require (
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
modernc.org/sqlite v1.44.1 // indirect
)

View File

@@ -51,8 +51,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE=
github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
@@ -63,8 +61,6 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
@@ -117,8 +113,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
@@ -141,7 +135,6 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
@@ -200,8 +193,6 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
@@ -227,12 +218,8 @@ github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4Vi
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
@@ -343,8 +330,6 @@ golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
@@ -372,7 +357,6 @@ golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY=
golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -395,12 +379,4 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas=
modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -19,9 +19,7 @@ const (
RunModeSimple = "simple"
)
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' https://challenges.cloudflare.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
@@ -38,30 +36,33 @@ const (
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
}
// UpdateConfig 在线更新相关配置
type UpdateConfig struct {
// ProxyURL 用于访问 GitHub 的代理地址
// 支持 http/https/socks5/socks5h 协议
// 例如: "http://127.0.0.1:7890", "socks5://127.0.0.1:1080"
ProxyURL string `mapstructure:"proxy_url"`
}
type GeminiConfig struct {
@@ -86,33 +87,6 @@ type GeminiTierQuotaConfig struct {
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
}
type UpdateConfig struct {
// ProxyURL 用于访问 GitHub 的代理地址
// 支持 http/https/socks5/socks5h 协议
// 例如: "http://127.0.0.1:7890", "socks5://127.0.0.1:1080"
ProxyURL string `mapstructure:"proxy_url"`
}
type LinuxDoConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
@@ -235,10 +209,6 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// SessionIdleTimeoutMinutes: 会话空闲超时时间(分钟),默认 5 分钟
// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制功能
// 空闲超过此时间的会话将被自动释放
SessionIdleTimeoutMinutes int `mapstructure:"session_idle_timeout_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时0表示禁用
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
@@ -258,43 +228,8 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"`
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
MaxAccountSwitches int `mapstructure:"max_account_switches"`
// Gemini 账户切换最大次数Gemini 平台单独配置,因 API 限制更严格)
MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"`
// Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用
AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"`
// Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
// TLSFingerprint: TLS指纹伪装配置
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
}
// TLSFingerprintConfig TLS指纹伪装配置
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
type TLSFingerprintConfig struct {
// Enabled: 是否全局启用TLS指纹功能
Enabled bool `mapstructure:"enabled"`
// Profiles: 预定义的TLS指纹配置模板
// key 为模板名称,如 "claude_cli_v2", "chrome_120" 等
Profiles map[string]TLSProfileConfig `mapstructure:"profiles"`
}
// TLSProfileConfig 单个TLS指纹模板的配置
type TLSProfileConfig struct {
// Name: 模板显示名称
Name string `mapstructure:"name"`
// EnableGREASE: 是否启用GREASE扩展Chrome使用Node.js不使用
EnableGREASE bool `mapstructure:"enable_grease"`
// CipherSuites: TLS加密套件列表空则使用内置默认值
CipherSuites []uint16 `mapstructure:"cipher_suites"`
// Curves: 椭圆曲线列表(空则使用内置默认值)
Curves []uint16 `mapstructure:"curves"`
// PointFormats: 点格式列表(空则使用内置默认值)
PointFormats []uint8 `mapstructure:"point_formats"`
}
// GatewaySchedulingConfig accounts scheduling configuration.
@@ -307,37 +242,11 @@ type GatewaySchedulingConfig struct {
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
// 兜底层账户选择策略: "last_used"(按最后使用时间排序,默认) 或 "random"(随机)
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
// 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
// 过期槽位清理周期0 表示禁用)
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
// 受控回源配置
DbFallbackEnabled bool `mapstructure:"db_fallback_enabled"`
// 受控回源超时0 表示不额外收紧超时
DbFallbackTimeoutSeconds int `mapstructure:"db_fallback_timeout_seconds"`
// 受控回源限流(实例级 QPS0 表示不限制
DbFallbackMaxQPS int `mapstructure:"db_fallback_max_qps"`
// Outbox 轮询与滞后阈值配置
// Outbox 轮询周期(秒)
OutboxPollIntervalSeconds int `mapstructure:"outbox_poll_interval_seconds"`
// Outbox 滞后告警阈值(秒)
OutboxLagWarnSeconds int `mapstructure:"outbox_lag_warn_seconds"`
// Outbox 触发强制重建阈值(秒)
OutboxLagRebuildSeconds int `mapstructure:"outbox_lag_rebuild_seconds"`
// Outbox 连续滞后触发次数
OutboxLagRebuildFailures int `mapstructure:"outbox_lag_rebuild_failures"`
// Outbox 积压触发重建阈值(行数)
OutboxBacklogRebuildRows int `mapstructure:"outbox_backlog_rebuild_rows"`
// 全量重建周期配置
// 全量重建周期0 表示禁用
FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"`
}
func (s *ServerConfig) Address() string {
@@ -420,47 +329,6 @@ func (r *RedisConfig) Address() string {
return fmt.Sprintf("%s:%d", r.Host, r.Port)
}
type OpsConfig struct {
// Enabled controls whether ops features should run.
//
// NOTE: vNext still has a DB-backed feature flag (ops_monitoring_enabled) for runtime on/off.
// This config flag is the "hard switch" for deployments that want to disable ops completely.
Enabled bool `mapstructure:"enabled"`
// UsePreaggregatedTables prefers ops_metrics_hourly/daily for long-window dashboard queries.
UsePreaggregatedTables bool `mapstructure:"use_preaggregated_tables"`
// Cleanup controls periodic deletion of old ops data to prevent unbounded growth.
Cleanup OpsCleanupConfig `mapstructure:"cleanup"`
// MetricsCollectorCache controls Redis caching for expensive per-window collector queries.
MetricsCollectorCache OpsMetricsCollectorCacheConfig `mapstructure:"metrics_collector_cache"`
// Pre-aggregation configuration.
Aggregation OpsAggregationConfig `mapstructure:"aggregation"`
}
type OpsCleanupConfig struct {
Enabled bool `mapstructure:"enabled"`
Schedule string `mapstructure:"schedule"`
// Retention days (0 disables that cleanup target).
//
// vNext requirement: default 30 days across ops datasets.
ErrorLogRetentionDays int `mapstructure:"error_log_retention_days"`
MinuteMetricsRetentionDays int `mapstructure:"minute_metrics_retention_days"`
HourlyMetricsRetentionDays int `mapstructure:"hourly_metrics_retention_days"`
}
type OpsAggregationConfig struct {
Enabled bool `mapstructure:"enabled"`
}
type OpsMetricsCollectorCacheConfig struct {
Enabled bool `mapstructure:"enabled"`
TTL time.Duration `mapstructure:"ttl"`
}
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpireHour int `mapstructure:"expire_hour"`
@@ -470,6 +338,30 @@ type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO
//
// 注意:这与上游账号的 OAuth例如 OpenAI/Gemini 账号接入)不是一回事。
// 这里是用于登录 Sub2API 本身的用户体系。
type LinuxDoConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"`
@@ -483,69 +375,6 @@ type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
}
// APIKeyAuthCacheConfig API Key 认证缓存配置
type APIKeyAuthCacheConfig struct {
L1Size int `mapstructure:"l1_size"`
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
L2TTLSeconds int `mapstructure:"l2_ttl_seconds"`
NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"`
JitterPercent int `mapstructure:"jitter_percent"`
Singleflight bool `mapstructure:"singleflight"`
}
// DashboardCacheConfig 仪表盘统计缓存配置
type DashboardCacheConfig struct {
// Enabled: 是否启用仪表盘缓存
Enabled bool `mapstructure:"enabled"`
// KeyPrefix: Redis key 前缀,用于多环境隔离
KeyPrefix string `mapstructure:"key_prefix"`
// StatsFreshTTLSeconds: 缓存命中认为“新鲜”的时间窗口(秒)
StatsFreshTTLSeconds int `mapstructure:"stats_fresh_ttl_seconds"`
// StatsTTLSeconds: Redis 缓存总 TTL
StatsTTLSeconds int `mapstructure:"stats_ttl_seconds"`
// StatsRefreshTimeoutSeconds: 异步刷新超时(秒)
StatsRefreshTimeoutSeconds int `mapstructure:"stats_refresh_timeout_seconds"`
}
// DashboardAggregationConfig 仪表盘预聚合配置
type DashboardAggregationConfig struct {
// Enabled: 是否启用预聚合作业
Enabled bool `mapstructure:"enabled"`
// IntervalSeconds: 聚合刷新间隔(秒)
IntervalSeconds int `mapstructure:"interval_seconds"`
// LookbackSeconds: 回看窗口(秒)
LookbackSeconds int `mapstructure:"lookback_seconds"`
// BackfillEnabled: 是否允许全量回填
BackfillEnabled bool `mapstructure:"backfill_enabled"`
// BackfillMaxDays: 回填最大跨度(天)
BackfillMaxDays int `mapstructure:"backfill_max_days"`
// Retention: 各表保留窗口(天)
Retention DashboardAggregationRetentionConfig `mapstructure:"retention"`
// RecomputeDays: 启动时重算最近 N 天
RecomputeDays int `mapstructure:"recompute_days"`
}
// DashboardAggregationRetentionConfig 预聚合保留窗口
type DashboardAggregationRetentionConfig struct {
UsageLogsDays int `mapstructure:"usage_logs_days"`
HourlyDays int `mapstructure:"hourly_days"`
DailyDays int `mapstructure:"daily_days"`
}
// UsageCleanupConfig 使用记录清理任务配置
type UsageCleanupConfig struct {
// Enabled: 是否启用清理任务执行器
Enabled bool `mapstructure:"enabled"`
// MaxRangeDays: 单次任务允许的最大时间跨度(天)
MaxRangeDays int `mapstructure:"max_range_days"`
// BatchSize: 单批删除数量
BatchSize int `mapstructure:"batch_size"`
// WorkerIntervalSeconds: 后台任务轮询间隔(秒)
WorkerIntervalSeconds int `mapstructure:"worker_interval_seconds"`
// TaskTimeoutSeconds: 单次任务最大执行时长(秒)
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
}
func NormalizeRunMode(value string) string {
normalized := strings.ToLower(strings.TrimSpace(value))
switch normalized {
@@ -611,7 +440,6 @@ func Load() (*Config, error) {
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
@@ -650,6 +478,81 @@ func Load() (*Config, error) {
return &cfg, nil
}
// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL禁止 fragment
func ValidateAbsoluteHTTPURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
// ValidateFrontendRedirectURL 校验前端回调地址:
// - 允许同源相对路径(以 / 开头)
// - 或绝对 http(s) URL禁止 fragment
func ValidateFrontendRedirectURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
if strings.ContainsAny(raw, "\r\n") {
return fmt.Errorf("contains invalid characters")
}
if strings.HasPrefix(raw, "/") {
if strings.HasPrefix(raw, "//") {
return fmt.Errorf("must not start with //")
}
return nil
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute http(s) url or relative path")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
func isHTTPScheme(scheme string) bool {
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
}
func warnIfInsecureURL(field, raw string) {
u, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
return
}
if strings.EqualFold(u.Scheme, "http") {
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
}
}
func setDefaults() {
viper.SetDefault("run_mode", RunModeStandard)
@@ -699,7 +602,7 @@ func setDefaults() {
// Turnstile
viper.SetDefault("turnstile.required", false)
// LinuxDo Connect OAuth 登录
// LinuxDo Connect OAuth 登录(终端用户 SSO
viper.SetDefault("linuxdo_connect.enabled", false)
viper.SetDefault("linuxdo_connect.client_id", "")
viper.SetDefault("linuxdo_connect.client_secret", "")
@@ -738,20 +641,6 @@ func setDefaults() {
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
// Ops (vNext)
viper.SetDefault("ops.enabled", true)
viper.SetDefault("ops.use_preaggregated_tables", false)
viper.SetDefault("ops.cleanup.enabled", true)
viper.SetDefault("ops.cleanup.schedule", "0 2 * * *")
// Retention days: vNext defaults to 30 days across ops datasets.
viper.SetDefault("ops.cleanup.error_log_retention_days", 30)
viper.SetDefault("ops.cleanup.minute_metrics_retention_days", 30)
viper.SetDefault("ops.cleanup.hourly_metrics_retention_days", 30)
viper.SetDefault("ops.aggregation.enabled", true)
viper.SetDefault("ops.metrics_collector_cache.enabled", true)
// TTL should be slightly larger than collection interval (1m) to maximize cross-replica cache hits.
viper.SetDefault("ops.metrics_collector_cache.ttl", 65*time.Second)
// JWT
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
@@ -780,48 +669,12 @@ func setDefaults() {
// Timezone (default to Asia/Shanghai for Chinese users)
viper.SetDefault("timezone", "Asia/Shanghai")
// API Key auth cache
viper.SetDefault("api_key_auth_cache.l1_size", 65535)
viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15)
viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300)
viper.SetDefault("api_key_auth_cache.negative_ttl_seconds", 30)
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
viper.SetDefault("api_key_auth_cache.singleflight", true)
// Dashboard cache
viper.SetDefault("dashboard_cache.enabled", true)
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")
viper.SetDefault("dashboard_cache.stats_fresh_ttl_seconds", 15)
viper.SetDefault("dashboard_cache.stats_ttl_seconds", 30)
viper.SetDefault("dashboard_cache.stats_refresh_timeout_seconds", 30)
// Dashboard aggregation
viper.SetDefault("dashboard_aggregation.enabled", true)
viper.SetDefault("dashboard_aggregation.interval_seconds", 60)
viper.SetDefault("dashboard_aggregation.lookback_seconds", 120)
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
// Usage cleanup task
viper.SetDefault("usage_cleanup.enabled", true)
viper.SetDefault("usage_cleanup.max_range_days", 31)
viper.SetDefault("usage_cleanup.batch_size", 5000)
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
// Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", true)
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_account_switches", 10)
viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
@@ -836,23 +689,11 @@ func setDefaults() {
viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0)
viper.SetDefault("gateway.scheduling.db_fallback_max_qps", 0)
viper.SetDefault("gateway.scheduling.outbox_poll_interval_seconds", 1)
viper.SetDefault("gateway.scheduling.outbox_lag_warn_seconds", 5)
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_seconds", 10)
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
// TLS指纹伪装配置默认关闭需要账号级别单独启用
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
viper.SetDefault("concurrency.ping_interval", 10)
// TokenRefresh
@@ -869,6 +710,10 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
// Update - 在线更新配置
// 代理地址为空表示直连 GitHub适用于海外服务器
viper.SetDefault("update.proxy_url", "")
}
func (c *Config) Validate() error {
@@ -909,8 +754,7 @@ func (c *Config) Validate() error {
if method == "none" && !c.LinuxDo.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
}
if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" {
@@ -983,105 +827,6 @@ func (c *Config) Validate() error {
if c.Redis.MinIdleConns > c.Redis.PoolSize {
return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size")
}
if c.Dashboard.Enabled {
if c.Dashboard.StatsFreshTTLSeconds <= 0 {
return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be positive")
}
if c.Dashboard.StatsTTLSeconds <= 0 {
return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be positive")
}
if c.Dashboard.StatsRefreshTimeoutSeconds <= 0 {
return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be positive")
}
if c.Dashboard.StatsFreshTTLSeconds > c.Dashboard.StatsTTLSeconds {
return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be <= dashboard_cache.stats_ttl_seconds")
}
} else {
if c.Dashboard.StatsFreshTTLSeconds < 0 {
return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be non-negative")
}
if c.Dashboard.StatsTTLSeconds < 0 {
return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be non-negative")
}
if c.Dashboard.StatsRefreshTimeoutSeconds < 0 {
return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be non-negative")
}
}
if c.DashboardAgg.Enabled {
if c.DashboardAgg.IntervalSeconds <= 0 {
return fmt.Errorf("dashboard_aggregation.interval_seconds must be positive")
}
if c.DashboardAgg.LookbackSeconds < 0 {
return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative")
}
if c.DashboardAgg.BackfillMaxDays < 0 {
return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative")
}
if c.DashboardAgg.BackfillEnabled && c.DashboardAgg.BackfillMaxDays == 0 {
return fmt.Errorf("dashboard_aggregation.backfill_max_days must be positive")
}
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
}
if c.DashboardAgg.Retention.HourlyDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
}
if c.DashboardAgg.Retention.DailyDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.daily_days must be positive")
}
if c.DashboardAgg.RecomputeDays < 0 {
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
}
} else {
if c.DashboardAgg.IntervalSeconds < 0 {
return fmt.Errorf("dashboard_aggregation.interval_seconds must be non-negative")
}
if c.DashboardAgg.LookbackSeconds < 0 {
return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative")
}
if c.DashboardAgg.BackfillMaxDays < 0 {
return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative")
}
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
}
if c.DashboardAgg.Retention.HourlyDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
}
if c.DashboardAgg.Retention.DailyDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.daily_days must be non-negative")
}
if c.DashboardAgg.RecomputeDays < 0 {
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
}
}
if c.UsageCleanup.Enabled {
if c.UsageCleanup.MaxRangeDays <= 0 {
return fmt.Errorf("usage_cleanup.max_range_days must be positive")
}
if c.UsageCleanup.BatchSize <= 0 {
return fmt.Errorf("usage_cleanup.batch_size must be positive")
}
if c.UsageCleanup.WorkerIntervalSeconds <= 0 {
return fmt.Errorf("usage_cleanup.worker_interval_seconds must be positive")
}
if c.UsageCleanup.TaskTimeoutSeconds <= 0 {
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be positive")
}
} else {
if c.UsageCleanup.MaxRangeDays < 0 {
return fmt.Errorf("usage_cleanup.max_range_days must be non-negative")
}
if c.UsageCleanup.BatchSize < 0 {
return fmt.Errorf("usage_cleanup.batch_size must be non-negative")
}
if c.UsageCleanup.WorkerIntervalSeconds < 0 {
return fmt.Errorf("usage_cleanup.worker_interval_seconds must be non-negative")
}
if c.UsageCleanup.TaskTimeoutSeconds < 0 {
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
}
}
if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive")
}
@@ -1152,50 +897,6 @@ func (c *Config) Validate() error {
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
}
if c.Gateway.Scheduling.DbFallbackTimeoutSeconds < 0 {
return fmt.Errorf("gateway.scheduling.db_fallback_timeout_seconds must be non-negative")
}
if c.Gateway.Scheduling.DbFallbackMaxQPS < 0 {
return fmt.Errorf("gateway.scheduling.db_fallback_max_qps must be non-negative")
}
if c.Gateway.Scheduling.OutboxPollIntervalSeconds <= 0 {
return fmt.Errorf("gateway.scheduling.outbox_poll_interval_seconds must be positive")
}
if c.Gateway.Scheduling.OutboxLagWarnSeconds < 0 {
return fmt.Errorf("gateway.scheduling.outbox_lag_warn_seconds must be non-negative")
}
if c.Gateway.Scheduling.OutboxLagRebuildSeconds < 0 {
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be non-negative")
}
if c.Gateway.Scheduling.OutboxLagRebuildFailures <= 0 {
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_failures must be positive")
}
if c.Gateway.Scheduling.OutboxBacklogRebuildRows < 0 {
return fmt.Errorf("gateway.scheduling.outbox_backlog_rebuild_rows must be non-negative")
}
if c.Gateway.Scheduling.FullRebuildIntervalSeconds < 0 {
return fmt.Errorf("gateway.scheduling.full_rebuild_interval_seconds must be non-negative")
}
if c.Gateway.Scheduling.OutboxLagWarnSeconds > 0 &&
c.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 &&
c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds {
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds")
}
if c.Ops.MetricsCollectorCache.TTL < 0 {
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
}
if c.Ops.Cleanup.ErrorLogRetentionDays < 0 {
return fmt.Errorf("ops.cleanup.error_log_retention_days must be non-negative")
}
if c.Ops.Cleanup.MinuteMetricsRetentionDays < 0 {
return fmt.Errorf("ops.cleanup.minute_metrics_retention_days must be non-negative")
}
if c.Ops.Cleanup.HourlyMetricsRetentionDays < 0 {
return fmt.Errorf("ops.cleanup.hourly_metrics_retention_days must be non-negative")
}
if c.Ops.Cleanup.Enabled && strings.TrimSpace(c.Ops.Cleanup.Schedule) == "" {
return fmt.Errorf("ops.cleanup.schedule is required when ops.cleanup.enabled=true")
}
if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
}
@@ -1272,77 +973,3 @@ func GetServerAddress() string {
port := v.GetInt("server.port")
return fmt.Sprintf("%s:%d", host, port)
}
// ValidateAbsoluteHTTPURL 验证是否为有效的绝对 HTTP(S) URL
func ValidateAbsoluteHTTPURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
// ValidateFrontendRedirectURL 验证前端重定向 URL可以是绝对 URL 或相对路径)
func ValidateFrontendRedirectURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
if strings.ContainsAny(raw, "\r\n") {
return fmt.Errorf("contains invalid characters")
}
if strings.HasPrefix(raw, "/") {
if strings.HasPrefix(raw, "//") {
return fmt.Errorf("must not start with //")
}
return nil
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute http(s) url or relative path")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
// isHTTPScheme 检查是否为 HTTP 或 HTTPS 协议
func isHTTPScheme(scheme string) bool {
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
}
func warnIfInsecureURL(field, raw string) {
u, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
return
}
if strings.EqualFold(u.Scheme, "http") {
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
}
}

View File

@@ -39,8 +39,8 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 120*time.Second {
t.Fatalf("StickySessionWaitTimeout = %v, want 120s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second {
t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
@@ -141,712 +141,3 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
}
}
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.Dashboard.Enabled {
t.Fatalf("Dashboard.Enabled = false, want true")
}
if cfg.Dashboard.KeyPrefix != "sub2api:" {
t.Fatalf("Dashboard.KeyPrefix = %q, want %q", cfg.Dashboard.KeyPrefix, "sub2api:")
}
if cfg.Dashboard.StatsFreshTTLSeconds != 15 {
t.Fatalf("Dashboard.StatsFreshTTLSeconds = %d, want 15", cfg.Dashboard.StatsFreshTTLSeconds)
}
if cfg.Dashboard.StatsTTLSeconds != 30 {
t.Fatalf("Dashboard.StatsTTLSeconds = %d, want 30", cfg.Dashboard.StatsTTLSeconds)
}
if cfg.Dashboard.StatsRefreshTimeoutSeconds != 30 {
t.Fatalf("Dashboard.StatsRefreshTimeoutSeconds = %d, want 30", cfg.Dashboard.StatsRefreshTimeoutSeconds)
}
}
func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Dashboard.Enabled = true
cfg.Dashboard.StatsFreshTTLSeconds = 10
cfg.Dashboard.StatsTTLSeconds = 5
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for stats_fresh_ttl_seconds > stats_ttl_seconds, got nil")
}
if !strings.Contains(err.Error(), "dashboard_cache.stats_fresh_ttl_seconds") {
t.Fatalf("Validate() expected stats_fresh_ttl_seconds error, got: %v", err)
}
}
func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Dashboard.Enabled = false
cfg.Dashboard.StatsTTLSeconds = -1
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for negative stats_ttl_seconds, got nil")
}
if !strings.Contains(err.Error(), "dashboard_cache.stats_ttl_seconds") {
t.Fatalf("Validate() expected stats_ttl_seconds error, got: %v", err)
}
}
func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.DashboardAgg.Enabled {
t.Fatalf("DashboardAgg.Enabled = false, want true")
}
if cfg.DashboardAgg.IntervalSeconds != 60 {
t.Fatalf("DashboardAgg.IntervalSeconds = %d, want 60", cfg.DashboardAgg.IntervalSeconds)
}
if cfg.DashboardAgg.LookbackSeconds != 120 {
t.Fatalf("DashboardAgg.LookbackSeconds = %d, want 120", cfg.DashboardAgg.LookbackSeconds)
}
if cfg.DashboardAgg.BackfillEnabled {
t.Fatalf("DashboardAgg.BackfillEnabled = true, want false")
}
if cfg.DashboardAgg.BackfillMaxDays != 31 {
t.Fatalf("DashboardAgg.BackfillMaxDays = %d, want 31", cfg.DashboardAgg.BackfillMaxDays)
}
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
}
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
}
if cfg.DashboardAgg.Retention.DailyDays != 730 {
t.Fatalf("DashboardAgg.Retention.DailyDays = %d, want 730", cfg.DashboardAgg.Retention.DailyDays)
}
if cfg.DashboardAgg.RecomputeDays != 2 {
t.Fatalf("DashboardAgg.RecomputeDays = %d, want 2", cfg.DashboardAgg.RecomputeDays)
}
}
func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.DashboardAgg.Enabled = false
cfg.DashboardAgg.IntervalSeconds = -1
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for negative dashboard_aggregation.interval_seconds, got nil")
}
if !strings.Contains(err.Error(), "dashboard_aggregation.interval_seconds") {
t.Fatalf("Validate() expected interval_seconds error, got: %v", err)
}
}
func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.DashboardAgg.BackfillEnabled = true
cfg.DashboardAgg.BackfillMaxDays = 0
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for dashboard_aggregation.backfill_max_days, got nil")
}
if !strings.Contains(err.Error(), "dashboard_aggregation.backfill_max_days") {
t.Fatalf("Validate() expected backfill_max_days error, got: %v", err)
}
}
func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.UsageCleanup.Enabled {
t.Fatalf("UsageCleanup.Enabled = false, want true")
}
if cfg.UsageCleanup.MaxRangeDays != 31 {
t.Fatalf("UsageCleanup.MaxRangeDays = %d, want 31", cfg.UsageCleanup.MaxRangeDays)
}
if cfg.UsageCleanup.BatchSize != 5000 {
t.Fatalf("UsageCleanup.BatchSize = %d, want 5000", cfg.UsageCleanup.BatchSize)
}
if cfg.UsageCleanup.WorkerIntervalSeconds != 10 {
t.Fatalf("UsageCleanup.WorkerIntervalSeconds = %d, want 10", cfg.UsageCleanup.WorkerIntervalSeconds)
}
if cfg.UsageCleanup.TaskTimeoutSeconds != 1800 {
t.Fatalf("UsageCleanup.TaskTimeoutSeconds = %d, want 1800", cfg.UsageCleanup.TaskTimeoutSeconds)
}
}
func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.UsageCleanup.Enabled = true
cfg.UsageCleanup.MaxRangeDays = 0
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for usage_cleanup.max_range_days, got nil")
}
if !strings.Contains(err.Error(), "usage_cleanup.max_range_days") {
t.Fatalf("Validate() expected max_range_days error, got: %v", err)
}
}
func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.UsageCleanup.Enabled = false
cfg.UsageCleanup.BatchSize = -1
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for usage_cleanup.batch_size, got nil")
}
if !strings.Contains(err.Error(), "usage_cleanup.batch_size") {
t.Fatalf("Validate() expected batch_size error, got: %v", err)
}
}
func TestConfigAddressHelpers(t *testing.T) {
server := ServerConfig{Host: "127.0.0.1", Port: 9000}
if server.Address() != "127.0.0.1:9000" {
t.Fatalf("ServerConfig.Address() = %q", server.Address())
}
dbCfg := DatabaseConfig{
Host: "localhost",
Port: 5432,
User: "postgres",
Password: "",
DBName: "sub2api",
SSLMode: "disable",
}
if !strings.Contains(dbCfg.DSN(), "password=") {
} else {
t.Fatalf("DatabaseConfig.DSN() should not include password when empty")
}
dbCfg.Password = "secret"
if !strings.Contains(dbCfg.DSN(), "password=secret") {
t.Fatalf("DatabaseConfig.DSN() missing password")
}
dbCfg.Password = ""
if strings.Contains(dbCfg.DSNWithTimezone("UTC"), "password=") {
t.Fatalf("DatabaseConfig.DSNWithTimezone() should omit password when empty")
}
if !strings.Contains(dbCfg.DSNWithTimezone(""), "TimeZone=Asia/Shanghai") {
t.Fatalf("DatabaseConfig.DSNWithTimezone() should use default timezone")
}
if !strings.Contains(dbCfg.DSNWithTimezone("UTC"), "TimeZone=UTC") {
t.Fatalf("DatabaseConfig.DSNWithTimezone() should use provided timezone")
}
redis := RedisConfig{Host: "redis", Port: 6379}
if redis.Address() != "redis:6379" {
t.Fatalf("RedisConfig.Address() = %q", redis.Address())
}
}
func TestNormalizeStringSlice(t *testing.T) {
values := normalizeStringSlice([]string{" a ", "", "b", " ", "c"})
if len(values) != 3 || values[0] != "a" || values[1] != "b" || values[2] != "c" {
t.Fatalf("normalizeStringSlice() unexpected result: %#v", values)
}
if normalizeStringSlice(nil) != nil {
t.Fatalf("normalizeStringSlice(nil) expected nil slice")
}
}
func TestGetServerAddressFromEnv(t *testing.T) {
t.Setenv("SERVER_HOST", "127.0.0.1")
t.Setenv("SERVER_PORT", "9090")
address := GetServerAddress()
if address != "127.0.0.1:9090" {
t.Fatalf("GetServerAddress() = %q", address)
}
}
func TestValidateAbsoluteHTTPURL(t *testing.T) {
if err := ValidateAbsoluteHTTPURL("https://example.com/path"); err != nil {
t.Fatalf("ValidateAbsoluteHTTPURL valid url error: %v", err)
}
if err := ValidateAbsoluteHTTPURL(""); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject empty url")
}
if err := ValidateAbsoluteHTTPURL("/relative"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject relative url")
}
if err := ValidateAbsoluteHTTPURL("ftp://example.com"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject ftp scheme")
}
if err := ValidateAbsoluteHTTPURL("https://example.com/#frag"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject fragment")
}
}
func TestValidateFrontendRedirectURL(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
}
if err := ValidateFrontendRedirectURL("https://example.com/auth"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL absolute error: %v", err)
}
if err := ValidateFrontendRedirectURL("example.com/path"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject non-absolute url")
}
if err := ValidateFrontendRedirectURL("//evil.com"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject // prefix")
}
if err := ValidateFrontendRedirectURL("javascript:alert(1)"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject javascript scheme")
}
}
func TestWarnIfInsecureURL(t *testing.T) {
warnIfInsecureURL("test", "http://example.com")
warnIfInsecureURL("test", "bad://url")
}
func TestGenerateJWTSecretDefaultLength(t *testing.T) {
secret, err := generateJWTSecret(0)
if err != nil {
t.Fatalf("generateJWTSecret error: %v", err)
}
if len(secret) == 0 {
t.Fatalf("generateJWTSecret returned empty string")
}
}
func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Ops.Cleanup.Enabled = true
cfg.Ops.Cleanup.Schedule = ""
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for ops.cleanup.schedule")
}
if !strings.Contains(err.Error(), "ops.cleanup.schedule") {
t.Fatalf("Validate() expected ops.cleanup.schedule error, got: %v", err)
}
}
func TestValidateConcurrencyPingInterval(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Concurrency.PingInterval = 3
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for concurrency.ping_interval")
}
if !strings.Contains(err.Error(), "concurrency.ping_interval") {
t.Fatalf("Validate() expected concurrency.ping_interval error, got: %v", err)
}
}
func TestProvideConfig(t *testing.T) {
viper.Reset()
if _, err := ProvideConfig(); err != nil {
t.Fatalf("ProvideConfig() error: %v", err)
}
}
func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Security.CSP.Enabled = true
cfg.Security.CSP.Policy = "default-src 'self'"
cfg.LinuxDo.Enabled = true
cfg.LinuxDo.ClientID = "client"
cfg.LinuxDo.ClientSecret = "secret"
cfg.LinuxDo.AuthorizeURL = "https://example.com/oauth2/authorize"
cfg.LinuxDo.TokenURL = "https://example.com/oauth2/token"
cfg.LinuxDo.UserInfoURL = "https://example.com/oauth2/userinfo"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() unexpected error: %v", err)
}
}
func TestValidateJWTSecretStrength(t *testing.T) {
if !isWeakJWTSecret("change-me-in-production") {
t.Fatalf("isWeakJWTSecret should detect weak secret")
}
if isWeakJWTSecret("StrongSecretValue") {
t.Fatalf("isWeakJWTSecret should accept strong secret")
}
}
func TestGenerateJWTSecretWithLength(t *testing.T) {
secret, err := generateJWTSecret(16)
if err != nil {
t.Fatalf("generateJWTSecret error: %v", err)
}
if len(secret) == 0 {
t.Fatalf("generateJWTSecret returned empty string")
}
}
func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
}
}
func TestValidateFrontendRedirectURLInvalidChars(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/\ncallback"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject invalid chars")
}
if err := ValidateFrontendRedirectURL("http://"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject missing host")
}
if err := ValidateFrontendRedirectURL("mailto:user@example.com"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject mailto")
}
}
func TestWarnIfInsecureURLHTTPS(t *testing.T) {
warnIfInsecureURL("secure", "https://example.com")
}
func TestValidateConfigErrors(t *testing.T) {
buildValid := func(t *testing.T) *Config {
t.Helper()
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
return cfg
}
cases := []struct {
name string
mutate func(*Config)
wantErr string
}{
{
name: "jwt expire hour positive",
mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
wantErr: "jwt.expire_hour must be positive",
},
{
name: "jwt expire hour max",
mutate: func(c *Config) { c.JWT.ExpireHour = 200 },
wantErr: "jwt.expire_hour must be <= 168",
},
{
name: "csp policy required",
mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" },
wantErr: "security.csp.policy",
},
{
name: "linuxdo client id required",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
c.LinuxDo.ClientID = ""
},
wantErr: "linuxdo_connect.client_id",
},
{
name: "linuxdo token auth method",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
c.LinuxDo.ClientID = "client"
c.LinuxDo.ClientSecret = "secret"
c.LinuxDo.AuthorizeURL = "https://example.com/authorize"
c.LinuxDo.TokenURL = "https://example.com/token"
c.LinuxDo.UserInfoURL = "https://example.com/userinfo"
c.LinuxDo.RedirectURL = "https://example.com/callback"
c.LinuxDo.FrontendRedirectURL = "/auth/callback"
c.LinuxDo.TokenAuthMethod = "invalid"
},
wantErr: "linuxdo_connect.token_auth_method",
},
{
name: "billing circuit breaker threshold",
mutate: func(c *Config) { c.Billing.CircuitBreaker.FailureThreshold = 0 },
wantErr: "billing.circuit_breaker.failure_threshold",
},
{
name: "billing circuit breaker reset",
mutate: func(c *Config) { c.Billing.CircuitBreaker.ResetTimeoutSeconds = 0 },
wantErr: "billing.circuit_breaker.reset_timeout_seconds",
},
{
name: "billing circuit breaker half open",
mutate: func(c *Config) { c.Billing.CircuitBreaker.HalfOpenRequests = 0 },
wantErr: "billing.circuit_breaker.half_open_requests",
},
{
name: "database max open conns",
mutate: func(c *Config) { c.Database.MaxOpenConns = 0 },
wantErr: "database.max_open_conns",
},
{
name: "database max lifetime",
mutate: func(c *Config) { c.Database.ConnMaxLifetimeMinutes = -1 },
wantErr: "database.conn_max_lifetime_minutes",
},
{
name: "database idle exceeds open",
mutate: func(c *Config) { c.Database.MaxIdleConns = c.Database.MaxOpenConns + 1 },
wantErr: "database.max_idle_conns cannot exceed",
},
{
name: "redis dial timeout",
mutate: func(c *Config) { c.Redis.DialTimeoutSeconds = 0 },
wantErr: "redis.dial_timeout_seconds",
},
{
name: "redis read timeout",
mutate: func(c *Config) { c.Redis.ReadTimeoutSeconds = 0 },
wantErr: "redis.read_timeout_seconds",
},
{
name: "redis write timeout",
mutate: func(c *Config) { c.Redis.WriteTimeoutSeconds = 0 },
wantErr: "redis.write_timeout_seconds",
},
{
name: "redis pool size",
mutate: func(c *Config) { c.Redis.PoolSize = 0 },
wantErr: "redis.pool_size",
},
{
name: "redis idle exceeds pool",
mutate: func(c *Config) { c.Redis.MinIdleConns = c.Redis.PoolSize + 1 },
wantErr: "redis.min_idle_conns cannot exceed",
},
{
name: "dashboard cache disabled negative",
mutate: func(c *Config) { c.Dashboard.Enabled = false; c.Dashboard.StatsTTLSeconds = -1 },
wantErr: "dashboard_cache.stats_ttl_seconds",
},
{
name: "dashboard cache fresh ttl positive",
mutate: func(c *Config) { c.Dashboard.Enabled = true; c.Dashboard.StatsFreshTTLSeconds = 0 },
wantErr: "dashboard_cache.stats_fresh_ttl_seconds",
},
{
name: "dashboard aggregation enabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.IntervalSeconds = 0 },
wantErr: "dashboard_aggregation.interval_seconds",
},
{
name: "dashboard aggregation backfill positive",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.BackfillEnabled = true
c.DashboardAgg.BackfillMaxDays = 0
},
wantErr: "dashboard_aggregation.backfill_max_days",
},
{
name: "dashboard aggregation retention",
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
wantErr: "dashboard_aggregation.retention.usage_logs_days",
},
{
name: "dashboard aggregation disabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
wantErr: "dashboard_aggregation.interval_seconds",
},
{
name: "usage cleanup max range",
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.MaxRangeDays = 0 },
wantErr: "usage_cleanup.max_range_days",
},
{
name: "usage cleanup worker interval",
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.WorkerIntervalSeconds = 0 },
wantErr: "usage_cleanup.worker_interval_seconds",
},
{
name: "usage cleanup batch size",
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.BatchSize = 0 },
wantErr: "usage_cleanup.batch_size",
},
{
name: "usage cleanup disabled negative",
mutate: func(c *Config) { c.UsageCleanup.Enabled = false; c.UsageCleanup.BatchSize = -1 },
wantErr: "usage_cleanup.batch_size",
},
{
name: "gateway max body size",
mutate: func(c *Config) { c.Gateway.MaxBodySize = 0 },
wantErr: "gateway.max_body_size",
},
{
name: "gateway max idle conns",
mutate: func(c *Config) { c.Gateway.MaxIdleConns = 0 },
wantErr: "gateway.max_idle_conns",
},
{
name: "gateway max idle conns per host",
mutate: func(c *Config) { c.Gateway.MaxIdleConnsPerHost = 0 },
wantErr: "gateway.max_idle_conns_per_host",
},
{
name: "gateway idle timeout",
mutate: func(c *Config) { c.Gateway.IdleConnTimeoutSeconds = 0 },
wantErr: "gateway.idle_conn_timeout_seconds",
},
{
name: "gateway max upstream clients",
mutate: func(c *Config) { c.Gateway.MaxUpstreamClients = 0 },
wantErr: "gateway.max_upstream_clients",
},
{
name: "gateway client idle ttl",
mutate: func(c *Config) { c.Gateway.ClientIdleTTLSeconds = 0 },
wantErr: "gateway.client_idle_ttl_seconds",
},
{
name: "gateway concurrency slot ttl",
mutate: func(c *Config) { c.Gateway.ConcurrencySlotTTLMinutes = 0 },
wantErr: "gateway.concurrency_slot_ttl_minutes",
},
{
name: "gateway max conns per host",
mutate: func(c *Config) { c.Gateway.MaxConnsPerHost = -1 },
wantErr: "gateway.max_conns_per_host",
},
{
name: "gateway connection isolation",
mutate: func(c *Config) { c.Gateway.ConnectionPoolIsolation = "invalid" },
wantErr: "gateway.connection_pool_isolation",
},
{
name: "gateway stream keepalive range",
mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
wantErr: "gateway.stream_keepalive_interval",
},
{
name: "gateway stream data interval range",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
wantErr: "gateway.stream_data_interval_timeout",
},
{
name: "gateway stream data interval negative",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 },
wantErr: "gateway.stream_data_interval_timeout must be non-negative",
},
{
name: "gateway max line size",
mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 },
wantErr: "gateway.max_line_size must be at least",
},
{
name: "gateway max line size negative",
mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
wantErr: "gateway.max_line_size must be non-negative",
},
{
name: "gateway scheduling sticky waiting",
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
wantErr: "gateway.scheduling.sticky_session_max_waiting",
},
{
name: "gateway scheduling outbox poll",
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },
wantErr: "gateway.scheduling.outbox_poll_interval_seconds",
},
{
name: "gateway scheduling outbox failures",
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxLagRebuildFailures = 0 },
wantErr: "gateway.scheduling.outbox_lag_rebuild_failures",
},
{
name: "gateway outbox lag rebuild",
mutate: func(c *Config) {
c.Gateway.Scheduling.OutboxLagWarnSeconds = 10
c.Gateway.Scheduling.OutboxLagRebuildSeconds = 5
},
wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
},
{
name: "ops metrics collector ttl",
mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
wantErr: "ops.metrics_collector_cache.ttl",
},
{
name: "ops cleanup retention",
mutate: func(c *Config) { c.Ops.Cleanup.ErrorLogRetentionDays = -1 },
wantErr: "ops.cleanup.error_log_retention_days",
},
{
name: "ops cleanup minute retention",
mutate: func(c *Config) { c.Ops.Cleanup.MinuteMetricsRetentionDays = -1 },
wantErr: "ops.cleanup.minute_metrics_retention_days",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
cfg := buildValid(t)
tt.mutate(cfg)
err := cfg.Validate()
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr)
}
})
}
}

View File

@@ -44,7 +44,6 @@ type AccountHandler struct {
accountTestService *service.AccountTestService
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
sessionLimitCache service.SessionLimitCache
}
// NewAccountHandler creates a new admin account handler
@@ -59,7 +58,6 @@ func NewAccountHandler(
accountTestService *service.AccountTestService,
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
sessionLimitCache service.SessionLimitCache,
) *AccountHandler {
return &AccountHandler{
adminService: adminService,
@@ -72,7 +70,6 @@ func NewAccountHandler(
accountTestService: accountTestService,
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
sessionLimitCache: sessionLimitCache,
}
}
@@ -87,7 +84,6 @@ type CreateAccountRequest struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
GroupIDs []int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
@@ -105,7 +101,6 @@ type UpdateAccountRequest struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
@@ -120,7 +115,6 @@ type BulkUpdateAccountsRequest struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
GroupIDs *[]int64 `json:"group_ids"`
@@ -133,9 +127,6 @@ type BulkUpdateAccountsRequest struct {
type AccountWithConcurrency struct {
*dto.Account
CurrentConcurrency int `json:"current_concurrency"`
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
}
// List handles listing all accounts with pagination
@@ -170,85 +161,13 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts = make(map[int64]int)
}
// 识别需要查询窗口费用和会话数的账号Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
for i := range accounts {
acc := &accounts[i]
if acc.IsAnthropicOAuthOrSetupToken() {
if acc.GetWindowCostLimit() > 0 {
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
}
if acc.GetMaxSessions() > 0 {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
}
}
}
// 并行获取窗口费用和活跃会话数
var windowCosts map[int64]float64
var activeSessions map[int64]int
// 获取活跃会话数(批量查询)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
if activeSessions == nil {
activeSessions = make(map[int64]int)
}
}
// 获取窗口费用(并行查询)
if len(windowCostAccountIDs) > 0 {
windowCosts = make(map[int64]float64)
var mu sync.Mutex
g, gctx := errgroup.WithContext(c.Request.Context())
g.SetLimit(10) // 限制并发数
for i := range accounts {
acc := &accounts[i]
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
continue
}
accCopy := acc // 闭包捕获
g.Go(func() error {
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime := accCopy.GetCurrentWindowStartTime()
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
if err == nil && stats != nil {
mu.Lock()
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
mu.Unlock()
}
return nil // 不返回错误,允许部分失败
})
}
_ = g.Wait()
}
// Build response with concurrency info
result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts {
acc := &accounts[i]
item := AccountWithConcurrency{
Account: dto.AccountFromService(acc),
CurrentConcurrency: concurrencyCounts[acc.ID],
result[i] = AccountWithConcurrency{
Account: dto.AccountFromService(&accounts[i]),
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
}
// 添加窗口费用(仅当启用时)
if windowCosts != nil {
if cost, ok := windowCosts[acc.ID]; ok {
item.CurrentWindowCost = &cost
}
}
// 添加活跃会话数(仅当启用时)
if activeSessions != nil {
if count, ok := activeSessions[acc.ID]; ok {
item.ActiveSessions = &count
}
}
result[i] = item
}
response.Paginated(c, result, total, page, pageSize)
@@ -280,10 +199,6 @@ func (h *AccountHandler) Create(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -298,7 +213,6 @@ func (h *AccountHandler) Create(c *gin.Context) {
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
@@ -344,10 +258,6 @@ func (h *AccountHandler) Update(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -361,7 +271,6 @@ func (h *AccountHandler) Update(c *gin.Context) {
ProxyID: req.ProxyID,
Concurrency: req.Concurrency, // 指针类型nil 表示未提供
Priority: req.Priority, // 指针类型nil 表示未提供
RateMultiplier: req.RateMultiplier,
Status: req.Status,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
@@ -541,36 +450,6 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
newCredentials[k] = v
}
}
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
if tokenInfo.ProjectIDMissing {
// 先更新凭证
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if updateErr != nil {
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
return
}
// 标记账户为 error
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id可能无法使用Antigravity"); setErr != nil {
response.InternalError(c, "Failed to set account error: "+setErr.Error())
return
}
response.Success(c, gin.H{
"message": "Token refreshed but project_id is missing, account marked as error",
"warning": "missing_project_id",
})
return
}
// 成功获取到 project_id如果之前是 missing_project_id 错误则清除
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
return
}
}
} else {
// Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
@@ -773,10 +652,6 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -785,7 +660,6 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
req.ProxyID != nil ||
req.Concurrency != nil ||
req.Priority != nil ||
req.RateMultiplier != nil ||
req.Status != "" ||
req.Schedulable != nil ||
req.GroupIDs != nil ||
@@ -803,7 +677,6 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
Status: req.Status,
Schedulable: req.Schedulable,
GroupIDs: req.GroupIDs,

View File

@@ -1,262 +0,0 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAdminRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc)
groupHandler := NewGroupHandler(adminSvc)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc)
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
router.POST("/api/v1/admin/users", userHandler.Create)
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
router.POST("/api/v1/admin/users/:id/balance", userHandler.UpdateBalance)
router.GET("/api/v1/admin/users/:id/api-keys", userHandler.GetUserAPIKeys)
router.GET("/api/v1/admin/users/:id/usage", userHandler.GetUserUsage)
router.GET("/api/v1/admin/groups", groupHandler.List)
router.GET("/api/v1/admin/groups/all", groupHandler.GetAll)
router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID)
router.POST("/api/v1/admin/groups", groupHandler.Create)
router.PUT("/api/v1/admin/groups/:id", groupHandler.Update)
router.DELETE("/api/v1/admin/groups/:id", groupHandler.Delete)
router.GET("/api/v1/admin/groups/:id/stats", groupHandler.GetStats)
router.GET("/api/v1/admin/groups/:id/api-keys", groupHandler.GetGroupAPIKeys)
router.GET("/api/v1/admin/proxies", proxyHandler.List)
router.GET("/api/v1/admin/proxies/all", proxyHandler.GetAll)
router.GET("/api/v1/admin/proxies/:id", proxyHandler.GetByID)
router.POST("/api/v1/admin/proxies", proxyHandler.Create)
router.PUT("/api/v1/admin/proxies/:id", proxyHandler.Update)
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
router.GET("/api/v1/admin/redeem-codes", redeemHandler.List)
router.GET("/api/v1/admin/redeem-codes/:id", redeemHandler.GetByID)
router.POST("/api/v1/admin/redeem-codes", redeemHandler.Generate)
router.DELETE("/api/v1/admin/redeem-codes/:id", redeemHandler.Delete)
router.POST("/api/v1/admin/redeem-codes/batch-delete", redeemHandler.BatchDelete)
router.POST("/api/v1/admin/redeem-codes/:id/expire", redeemHandler.Expire)
router.GET("/api/v1/admin/redeem-codes/:id/stats", redeemHandler.GetStats)
return router, adminSvc
}
func TestUserHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/users?page=1&page_size=20", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
body, _ := json.Marshal(createBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
updateBody := map[string]any{"email": "updated@example.com"}
body, _ = json.Marshal(updateBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/users/1", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/balance", bytes.NewBufferString(`{"balance":1,"operation":"add"}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/api-keys", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/usage?period=today", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestGroupHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/all", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ = json.Marshal(map[string]any{"name": "update"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/groups/2", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/groups/2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/api-keys", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestProxyHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/all", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"name": "proxy", "protocol": "http", "host": "localhost", "port": 8080})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ = json.Marshal(map[string]any{"name": "proxy2"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/proxies/4", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/proxies/4", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/test", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/accounts", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestRedeemHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"count": 1, "type": "balance", "value": 10})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/redeem-codes/5", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/5/expire", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}

View File

@@ -1,134 +0,0 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestParseTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-01&end_date=2024-01-02&timezone=UTC", nil)
c.Request = req
start, end := parseTimeRange(c)
require.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), start)
require.Equal(t, time.Date(2024, 1, 3, 0, 0, 0, 0, time.UTC), end)
req = httptest.NewRequest(http.MethodGet, "/?start_date=bad&timezone=UTC", nil)
c.Request = req
start, end = parseTimeRange(c)
require.False(t, start.IsZero())
require.False(t, end.IsZero())
}
func TestParseOpsViewParam(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?view=excluded", nil)
require.Equal(t, opsListViewExcluded, parseOpsViewParam(c))
c2, _ := gin.CreateTestContext(w)
c2.Request = httptest.NewRequest(http.MethodGet, "/?view=all", nil)
require.Equal(t, opsListViewAll, parseOpsViewParam(c2))
c3, _ := gin.CreateTestContext(w)
c3.Request = httptest.NewRequest(http.MethodGet, "/?view=unknown", nil)
require.Equal(t, opsListViewErrors, parseOpsViewParam(c3))
require.Equal(t, "", parseOpsViewParam(nil))
}
func TestParseOpsDuration(t *testing.T) {
dur, ok := parseOpsDuration("1h")
require.True(t, ok)
require.Equal(t, time.Hour, dur)
_, ok = parseOpsDuration("invalid")
require.False(t, ok)
}
func TestParseOpsTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
now := time.Now().UTC()
startStr := now.Add(-time.Hour).Format(time.RFC3339)
endStr := now.Format(time.RFC3339)
c.Request = httptest.NewRequest(http.MethodGet, "/?start_time="+startStr+"&end_time="+endStr, nil)
start, end, err := parseOpsTimeRange(c, "1h")
require.NoError(t, err)
require.True(t, start.Before(end))
c2, _ := gin.CreateTestContext(w)
c2.Request = httptest.NewRequest(http.MethodGet, "/?start_time=bad", nil)
_, _, err = parseOpsTimeRange(c2, "1h")
require.Error(t, err)
}
func TestParseOpsRealtimeWindow(t *testing.T) {
dur, label, ok := parseOpsRealtimeWindow("5m")
require.True(t, ok)
require.Equal(t, 5*time.Minute, dur)
require.Equal(t, "5min", label)
_, _, ok = parseOpsRealtimeWindow("invalid")
require.False(t, ok)
}
func TestPickThroughputBucketSeconds(t *testing.T) {
require.Equal(t, 60, pickThroughputBucketSeconds(30*time.Minute))
require.Equal(t, 300, pickThroughputBucketSeconds(6*time.Hour))
require.Equal(t, 3600, pickThroughputBucketSeconds(48*time.Hour))
}
func TestParseOpsQueryMode(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?mode=raw", nil)
require.Equal(t, service.ParseOpsQueryMode("raw"), parseOpsQueryMode(c))
require.Equal(t, service.OpsQueryMode(""), parseOpsQueryMode(nil))
}
func TestOpsAlertRuleValidation(t *testing.T) {
raw := map[string]json.RawMessage{
"name": json.RawMessage(`"High error rate"`),
"metric_type": json.RawMessage(`"error_rate"`),
"operator": json.RawMessage(`">"`),
"threshold": json.RawMessage(`90`),
}
validated, err := validateOpsAlertRulePayload(raw)
require.NoError(t, err)
require.Equal(t, "High error rate", validated.Name)
_, err = validateOpsAlertRulePayload(map[string]json.RawMessage{})
require.Error(t, err)
require.True(t, isPercentOrRateMetric("error_rate"))
require.False(t, isPercentOrRateMetric("concurrency_queue_depth"))
}
func TestOpsWSHelpers(t *testing.T) {
prefixes, invalid := parseTrustedProxyList("10.0.0.0/8,invalid")
require.Len(t, prefixes, 1)
require.Len(t, invalid, 1)
host := hostWithoutPort("example.com:443")
require.Equal(t, "example.com", host)
addr := netip.MustParseAddr("10.0.0.1")
require.True(t, isAddrInTrustedProxies(addr, prefixes))
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
}

View File

@@ -1,294 +0,0 @@
package admin
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type stubAdminService struct {
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
}
func newStubAdminService() *stubAdminService {
now := time.Now().UTC()
user := service.User{
ID: 1,
Email: "user@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
apiKey := service.APIKey{
ID: 10,
UserID: user.ID,
Key: "sk-test",
Name: "test",
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
group := service.Group{
ID: 2,
Name: "group",
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
account := service.Account{
ID: 3,
Name: "account",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
proxy := service.Proxy{
ID: 4,
Name: "proxy",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
redeem := service.RedeemCode{
ID: 5,
Code: "R-TEST",
Type: service.RedeemTypeBalance,
Value: 10,
Status: service.StatusUnused,
CreatedAt: now,
}
return &stubAdminService{
users: []service.User{user},
apiKeys: []service.APIKey{apiKey},
groups: []service.Group{group},
accounts: []service.Account{account},
proxies: []service.Proxy{proxy},
proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}},
redeems: []service.RedeemCode{redeem},
}
}
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) {
return s.users, int64(len(s.users)), nil
}
func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
for i := range s.users {
if s.users[i].ID == id {
return &s.users[i], nil
}
}
user := service.User{ID: id, Email: "user@example.com", Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) CreateUser(ctx context.Context, input *service.CreateUserInput) (*service.User, error) {
user := service.User{ID: 100, Email: input.Email, Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) UpdateUser(ctx context.Context, id int64, input *service.UpdateUserInput) (*service.User, error) {
user := service.User{ID: id, Email: "updated@example.com", Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) DeleteUser(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*service.User, error) {
user := service.User{ID: userID, Balance: balance, Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
return map[string]any{"user_id": userID}, nil
}
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
func (s *stubAdminService) GetAllGroups(ctx context.Context) ([]service.Group, error) {
return s.groups, nil
}
func (s *stubAdminService) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return s.groups, nil
}
func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Group, error) {
group := service.Group{ID: id, Name: "group", Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) {
group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) UpdateGroup(ctx context.Context, id int64, input *service.UpdateGroupInput) (*service.Group, error) {
group := service.Group{ID: id, Name: input.Name, Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) DeleteGroup(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil
}
func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
out := make([]*service.Account, 0, len(ids))
for _, id := range ids {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
out = append(out, &account)
}
return out, nil
}
func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) {
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) RefreshAccountCredentials(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable}
return &account, nil
}
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
return s.proxies, int64(len(s.proxies)), nil
}
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
return s.proxyCounts, int64(len(s.proxyCounts)), nil
}
func (s *stubAdminService) GetAllProxies(ctx context.Context) ([]service.Proxy, error) {
return s.proxies, nil
}
func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
return s.proxyCounts, nil
}
func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) {
proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) {
proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) {
proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) DeleteProxy(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) BatchDeleteProxies(ctx context.Context, ids []int64) (*service.ProxyBatchDeleteResult, error) {
return &service.ProxyBatchDeleteResult{DeletedIDs: ids}, nil
}
func (s *stubAdminService) GetProxyAccounts(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
return []service.ProxyAccountSummary{{ID: 1, Name: "account"}}, nil
}
func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
return false, nil
}
func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) {
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
}
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
return s.redeems, int64(len(s.redeems)), nil
}
func (s *stubAdminService) GetRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUnused}
return &code, nil
}
func (s *stubAdminService) GenerateRedeemCodes(ctx context.Context, input *service.GenerateRedeemCodesInput) ([]service.RedeemCode, error) {
return s.redeems, nil
}
func (s *stubAdminService) DeleteRedeemCode(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
return int64(len(ids)), nil
}
func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUsed}
return &code, nil
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -1,7 +1,6 @@
package admin
import (
"errors"
"strconv"
"time"
@@ -14,17 +13,15 @@ import (
// DashboardHandler handles admin dashboard statistics
type DashboardHandler struct {
dashboardService *service.DashboardService
aggregationService *service.DashboardAggregationService
startTime time.Time // Server start time for uptime calculation
dashboardService *service.DashboardService
startTime time.Time // Server start time for uptime calculation
}
// NewDashboardHandler creates a new admin dashboard handler
func NewDashboardHandler(dashboardService *service.DashboardService, aggregationService *service.DashboardAggregationService) *DashboardHandler {
func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
return &DashboardHandler{
dashboardService: dashboardService,
aggregationService: aggregationService,
startTime: time.Now(),
dashboardService: dashboardService,
startTime: time.Now(),
}
}
@@ -117,58 +114,6 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
// 性能指标
"rpm": stats.Rpm,
"tpm": stats.Tpm,
// 预聚合新鲜度
"hourly_active_users": stats.HourlyActiveUsers,
"stats_updated_at": stats.StatsUpdatedAt,
"stats_stale": stats.StatsStale,
})
}
type DashboardAggregationBackfillRequest struct {
Start string `json:"start"`
End string `json:"end"`
}
// BackfillAggregation handles triggering aggregation backfill
// POST /api/v1/admin/dashboard/aggregation/backfill
func (h *DashboardHandler) BackfillAggregation(c *gin.Context) {
if h.aggregationService == nil {
response.InternalError(c, "Aggregation service not available")
return
}
var req DashboardAggregationBackfillRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
start, err := time.Parse(time.RFC3339, req.Start)
if err != nil {
response.BadRequest(c, "Invalid start time")
return
}
end, err := time.Parse(time.RFC3339, req.End)
if err != nil {
response.BadRequest(c, "Invalid end time")
return
}
if err := h.aggregationService.TriggerBackfill(start, end); err != nil {
if errors.Is(err, service.ErrDashboardBackfillDisabled) {
response.Forbidden(c, "Backfill is disabled")
return
}
if errors.Is(err, service.ErrDashboardBackfillTooLarge) {
response.BadRequest(c, "Backfill range too large")
return
}
response.InternalError(c, "Failed to trigger backfill")
return
}
response.Success(c, gin.H{
"status": "accepted",
})
}
@@ -186,17 +131,13 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var model string
var stream *bool
var billingType *int8
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
@@ -207,35 +148,8 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
apiKeyID = id
}
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
accountID = id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = id
}
}
if modelStr := c.Query("model"); modelStr != "" {
model = modelStr
}
if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
@@ -251,15 +165,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var stream *bool
var billingType *int8
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
@@ -270,32 +181,8 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
apiKeyID = id
}
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
accountID = id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = id
}
}
if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return

View File

@@ -40,9 +40,6 @@ type CreateGroupRequest struct {
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
}
// UpdateGroupRequest represents update group request
@@ -63,9 +60,6 @@ type UpdateGroupRequest struct {
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
}
// List handles listing all groups with pagination
@@ -155,22 +149,20 @@ func (h *GroupHandler) Create(c *gin.Context) {
}
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -196,23 +188,21 @@ func (h *GroupHandler) Update(c *gin.Context) {
}
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: req.Status,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: req.Status,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
})
if err != nil {
response.ErrorFrom(c, err)

View File

@@ -1,602 +0,0 @@
package admin
import (
"encoding/json"
"fmt"
"math"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
)
var validOpsAlertMetricTypes = []string{
"success_rate",
"error_rate",
"upstream_error_rate",
"cpu_usage_percent",
"memory_usage_percent",
"concurrency_queue_depth",
}
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
set := make(map[string]struct{}, len(validOpsAlertMetricTypes))
for _, v := range validOpsAlertMetricTypes {
set[v] = struct{}{}
}
return set
}()
var validOpsAlertOperators = []string{">", "<", ">=", "<=", "==", "!="}
var validOpsAlertOperatorSet = func() map[string]struct{} {
set := make(map[string]struct{}, len(validOpsAlertOperators))
for _, v := range validOpsAlertOperators {
set[v] = struct{}{}
}
return set
}()
var validOpsAlertSeverities = []string{"P0", "P1", "P2", "P3"}
var validOpsAlertSeveritySet = func() map[string]struct{} {
set := make(map[string]struct{}, len(validOpsAlertSeverities))
for _, v := range validOpsAlertSeverities {
set[v] = struct{}{}
}
return set
}()
type opsAlertRuleValidatedInput struct {
Name string
MetricType string
Operator string
Threshold float64
Severity string
WindowMinutes int
SustainedMinutes int
CooldownMinutes int
Enabled bool
NotifyEmail bool
WindowProvided bool
SustainedProvided bool
CooldownProvided bool
SeverityProvided bool
EnabledProvided bool
NotifyProvided bool
}
func isPercentOrRateMetric(metricType string) bool {
switch metricType {
case "success_rate",
"error_rate",
"upstream_error_rate",
"cpu_usage_percent",
"memory_usage_percent":
return true
default:
return false
}
}
func validateOpsAlertRulePayload(raw map[string]json.RawMessage) (*opsAlertRuleValidatedInput, error) {
if raw == nil {
return nil, fmt.Errorf("invalid request body")
}
requiredFields := []string{"name", "metric_type", "operator", "threshold"}
for _, field := range requiredFields {
if _, ok := raw[field]; !ok {
return nil, fmt.Errorf("%s is required", field)
}
}
var name string
if err := json.Unmarshal(raw["name"], &name); err != nil || strings.TrimSpace(name) == "" {
return nil, fmt.Errorf("name is required")
}
name = strings.TrimSpace(name)
var metricType string
if err := json.Unmarshal(raw["metric_type"], &metricType); err != nil || strings.TrimSpace(metricType) == "" {
return nil, fmt.Errorf("metric_type is required")
}
metricType = strings.TrimSpace(metricType)
if _, ok := validOpsAlertMetricTypeSet[metricType]; !ok {
return nil, fmt.Errorf("metric_type must be one of: %s", strings.Join(validOpsAlertMetricTypes, ", "))
}
var operator string
if err := json.Unmarshal(raw["operator"], &operator); err != nil || strings.TrimSpace(operator) == "" {
return nil, fmt.Errorf("operator is required")
}
operator = strings.TrimSpace(operator)
if _, ok := validOpsAlertOperatorSet[operator]; !ok {
return nil, fmt.Errorf("operator must be one of: %s", strings.Join(validOpsAlertOperators, ", "))
}
var threshold float64
if err := json.Unmarshal(raw["threshold"], &threshold); err != nil {
return nil, fmt.Errorf("threshold must be a number")
}
if math.IsNaN(threshold) || math.IsInf(threshold, 0) {
return nil, fmt.Errorf("threshold must be a finite number")
}
if isPercentOrRateMetric(metricType) {
if threshold < 0 || threshold > 100 {
return nil, fmt.Errorf("threshold must be between 0 and 100 for metric_type %s", metricType)
}
} else if threshold < 0 {
return nil, fmt.Errorf("threshold must be >= 0")
}
validated := &opsAlertRuleValidatedInput{
Name: name,
MetricType: metricType,
Operator: operator,
Threshold: threshold,
}
if v, ok := raw["severity"]; ok {
validated.SeverityProvided = true
var sev string
if err := json.Unmarshal(v, &sev); err != nil {
return nil, fmt.Errorf("severity must be a string")
}
sev = strings.ToUpper(strings.TrimSpace(sev))
if sev != "" {
if _, ok := validOpsAlertSeveritySet[sev]; !ok {
return nil, fmt.Errorf("severity must be one of: %s", strings.Join(validOpsAlertSeverities, ", "))
}
validated.Severity = sev
}
}
if validated.Severity == "" {
validated.Severity = "P2"
}
if v, ok := raw["enabled"]; ok {
validated.EnabledProvided = true
if err := json.Unmarshal(v, &validated.Enabled); err != nil {
return nil, fmt.Errorf("enabled must be a boolean")
}
} else {
validated.Enabled = true
}
if v, ok := raw["notify_email"]; ok {
validated.NotifyProvided = true
if err := json.Unmarshal(v, &validated.NotifyEmail); err != nil {
return nil, fmt.Errorf("notify_email must be a boolean")
}
} else {
validated.NotifyEmail = true
}
if v, ok := raw["window_minutes"]; ok {
validated.WindowProvided = true
if err := json.Unmarshal(v, &validated.WindowMinutes); err != nil {
return nil, fmt.Errorf("window_minutes must be an integer")
}
switch validated.WindowMinutes {
case 1, 5, 60:
default:
return nil, fmt.Errorf("window_minutes must be one of: 1, 5, 60")
}
} else {
validated.WindowMinutes = 1
}
if v, ok := raw["sustained_minutes"]; ok {
validated.SustainedProvided = true
if err := json.Unmarshal(v, &validated.SustainedMinutes); err != nil {
return nil, fmt.Errorf("sustained_minutes must be an integer")
}
if validated.SustainedMinutes < 1 || validated.SustainedMinutes > 1440 {
return nil, fmt.Errorf("sustained_minutes must be between 1 and 1440")
}
} else {
validated.SustainedMinutes = 1
}
if v, ok := raw["cooldown_minutes"]; ok {
validated.CooldownProvided = true
if err := json.Unmarshal(v, &validated.CooldownMinutes); err != nil {
return nil, fmt.Errorf("cooldown_minutes must be an integer")
}
if validated.CooldownMinutes < 0 || validated.CooldownMinutes > 1440 {
return nil, fmt.Errorf("cooldown_minutes must be between 0 and 1440")
}
} else {
validated.CooldownMinutes = 0
}
return validated, nil
}
// ListAlertRules returns all ops alert rules.
// GET /api/v1/admin/ops/alert-rules
func (h *OpsHandler) ListAlertRules(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
rules, err := h.opsService.ListAlertRules(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, rules)
}
// CreateAlertRule creates an ops alert rule.
// POST /api/v1/admin/ops/alert-rules
func (h *OpsHandler) CreateAlertRule(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var raw map[string]json.RawMessage
if err := c.ShouldBindBodyWith(&raw, binding.JSON); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
validated, err := validateOpsAlertRulePayload(raw)
if err != nil {
response.BadRequest(c, err.Error())
return
}
var rule service.OpsAlertRule
if err := c.ShouldBindBodyWith(&rule, binding.JSON); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
rule.Name = validated.Name
rule.MetricType = validated.MetricType
rule.Operator = validated.Operator
rule.Threshold = validated.Threshold
rule.WindowMinutes = validated.WindowMinutes
rule.SustainedMinutes = validated.SustainedMinutes
rule.CooldownMinutes = validated.CooldownMinutes
rule.Severity = validated.Severity
rule.Enabled = validated.Enabled
rule.NotifyEmail = validated.NotifyEmail
created, err := h.opsService.CreateAlertRule(c.Request.Context(), &rule)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, created)
}
// UpdateAlertRule updates an existing ops alert rule.
// PUT /api/v1/admin/ops/alert-rules/:id
func (h *OpsHandler) UpdateAlertRule(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid rule ID")
return
}
var raw map[string]json.RawMessage
if err := c.ShouldBindBodyWith(&raw, binding.JSON); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
validated, err := validateOpsAlertRulePayload(raw)
if err != nil {
response.BadRequest(c, err.Error())
return
}
var rule service.OpsAlertRule
if err := c.ShouldBindBodyWith(&rule, binding.JSON); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
rule.ID = id
rule.Name = validated.Name
rule.MetricType = validated.MetricType
rule.Operator = validated.Operator
rule.Threshold = validated.Threshold
rule.WindowMinutes = validated.WindowMinutes
rule.SustainedMinutes = validated.SustainedMinutes
rule.CooldownMinutes = validated.CooldownMinutes
rule.Severity = validated.Severity
rule.Enabled = validated.Enabled
rule.NotifyEmail = validated.NotifyEmail
updated, err := h.opsService.UpdateAlertRule(c.Request.Context(), &rule)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, updated)
}
// DeleteAlertRule deletes an ops alert rule.
// DELETE /api/v1/admin/ops/alert-rules/:id
func (h *OpsHandler) DeleteAlertRule(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid rule ID")
return
}
if err := h.opsService.DeleteAlertRule(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
// GetAlertEvent returns a single ops alert event.
// GET /api/v1/admin/ops/alert-events/:id
func (h *OpsHandler) GetAlertEvent(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid event ID")
return
}
ev, err := h.opsService.GetAlertEventByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ev)
}
// UpdateAlertEventStatus updates an ops alert event status.
// PUT /api/v1/admin/ops/alert-events/:id/status
func (h *OpsHandler) UpdateAlertEventStatus(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid event ID")
return
}
var payload struct {
Status string `json:"status"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
payload.Status = strings.TrimSpace(payload.Status)
if payload.Status == "" {
response.BadRequest(c, "Invalid status")
return
}
if payload.Status != service.OpsAlertStatusResolved && payload.Status != service.OpsAlertStatusManualResolved {
response.BadRequest(c, "Invalid status")
return
}
var resolvedAt *time.Time
if payload.Status == service.OpsAlertStatusResolved || payload.Status == service.OpsAlertStatusManualResolved {
now := time.Now().UTC()
resolvedAt = &now
}
if err := h.opsService.UpdateAlertEventStatus(c.Request.Context(), id, payload.Status, resolvedAt); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"updated": true})
}
// ListAlertEvents lists recent ops alert events.
// GET /api/v1/admin/ops/alert-events
// CreateAlertSilence creates a scoped silence for ops alerts.
// POST /api/v1/admin/ops/alert-silences
func (h *OpsHandler) CreateAlertSilence(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var payload struct {
RuleID int64 `json:"rule_id"`
Platform string `json:"platform"`
GroupID *int64 `json:"group_id"`
Region *string `json:"region"`
Until string `json:"until"`
Reason string `json:"reason"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
until, err := time.Parse(time.RFC3339, strings.TrimSpace(payload.Until))
if err != nil {
response.BadRequest(c, "Invalid until")
return
}
createdBy := (*int64)(nil)
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
uid := subject.UserID
createdBy = &uid
}
silence := &service.OpsAlertSilence{
RuleID: payload.RuleID,
Platform: strings.TrimSpace(payload.Platform),
GroupID: payload.GroupID,
Region: payload.Region,
Until: until,
Reason: strings.TrimSpace(payload.Reason),
CreatedBy: createdBy,
}
created, err := h.opsService.CreateAlertSilence(c.Request.Context(), silence)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, created)
}
func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
limit := 20
if raw := strings.TrimSpace(c.Query("limit")); raw != "" {
n, err := strconv.Atoi(raw)
if err != nil || n <= 0 {
response.BadRequest(c, "Invalid limit")
return
}
limit = n
}
filter := &service.OpsAlertEventFilter{
Limit: limit,
Status: strings.TrimSpace(c.Query("status")),
Severity: strings.TrimSpace(c.Query("severity")),
}
if v := strings.TrimSpace(c.Query("email_sent")); v != "" {
vv := strings.ToLower(v)
switch vv {
case "true", "1":
b := true
filter.EmailSent = &b
case "false", "0":
b := false
filter.EmailSent = &b
default:
response.BadRequest(c, "Invalid email_sent")
return
}
}
// Cursor pagination: both params must be provided together.
rawTS := strings.TrimSpace(c.Query("before_fired_at"))
rawID := strings.TrimSpace(c.Query("before_id"))
if (rawTS == "") != (rawID == "") {
response.BadRequest(c, "before_fired_at and before_id must be provided together")
return
}
if rawTS != "" {
ts, err := time.Parse(time.RFC3339Nano, rawTS)
if err != nil {
if t2, err2 := time.Parse(time.RFC3339, rawTS); err2 == nil {
ts = t2
} else {
response.BadRequest(c, "Invalid before_fired_at")
return
}
}
filter.BeforeFiredAt = &ts
}
if rawID != "" {
id, err := strconv.ParseInt(rawID, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid before_id")
return
}
filter.BeforeID = &id
}
// Optional global filter support (platform/group/time range).
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if startTime, endTime, err := parseOpsTimeRange(c, "24h"); err == nil {
// Only apply when explicitly provided to avoid surprising default narrowing.
if strings.TrimSpace(c.Query("start_time")) != "" || strings.TrimSpace(c.Query("end_time")) != "" || strings.TrimSpace(c.Query("time_range")) != "" {
filter.StartTime = &startTime
filter.EndTime = &endTime
}
} else {
response.BadRequest(c, err.Error())
return
}
events, err := h.opsService.ListAlertEvents(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, events)
}

View File

@@ -1,243 +0,0 @@
package admin
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GetDashboardOverview returns vNext ops dashboard overview (raw path).
// GET /api/v1/admin/ops/dashboard/overview
func (h *OpsHandler) GetDashboardOverview(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
data, err := h.opsService.GetDashboardOverview(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
// GetDashboardThroughputTrend returns throughput time series (raw path).
// GET /api/v1/admin/ops/dashboard/throughput-trend
func (h *OpsHandler) GetDashboardThroughputTrend(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
data, err := h.opsService.GetThroughputTrend(c.Request.Context(), filter, bucketSeconds)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
// GetDashboardLatencyHistogram returns the latency distribution histogram (success requests).
// GET /api/v1/admin/ops/dashboard/latency-histogram
func (h *OpsHandler) GetDashboardLatencyHistogram(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
data, err := h.opsService.GetLatencyHistogram(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
// GetDashboardErrorTrend returns error counts time series (raw path).
// GET /api/v1/admin/ops/dashboard/error-trend
func (h *OpsHandler) GetDashboardErrorTrend(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
data, err := h.opsService.GetErrorTrend(c.Request.Context(), filter, bucketSeconds)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
// GetDashboardErrorDistribution returns error distribution by status code (raw path).
// GET /api/v1/admin/ops/dashboard/error-distribution
func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
data, err := h.opsService.GetErrorDistribution(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
func pickThroughputBucketSeconds(window time.Duration) int {
// Keep buckets predictable and avoid huge responses.
switch {
case window <= 2*time.Hour:
return 60
case window <= 24*time.Hour:
return 300
default:
return 3600
}
}
func parseOpsQueryMode(c *gin.Context) service.OpsQueryMode {
if c == nil {
return ""
}
raw := strings.TrimSpace(c.Query("mode"))
if raw == "" {
// Empty means "use server default" (DB setting ops_query_mode_default).
return ""
}
return service.ParseOpsQueryMode(raw)
}

View File

@@ -1,925 +0,0 @@
package admin
import (
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type OpsHandler struct {
opsService *service.OpsService
}
// GetErrorLogByID returns ops error log detail.
// GET /api/v1/admin/ops/errors/:id
func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, detail)
}
const (
opsListViewErrors = "errors"
opsListViewExcluded = "excluded"
opsListViewAll = "all"
)
func parseOpsViewParam(c *gin.Context) string {
if c == nil {
return ""
}
v := strings.ToLower(strings.TrimSpace(c.Query("view")))
switch v {
case "", opsListViewErrors:
return opsListViewErrors
case opsListViewExcluded:
return opsListViewExcluded
case opsListViewAll:
return opsListViewAll
default:
return opsListViewErrors
}
}
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
return &OpsHandler{opsService: opsService}
}
// GetErrorLogs lists ops error logs.
// GET /api/v1/admin/ops/errors
func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
// Ops list can be larger than standard admin tables.
if pageSize > 500 {
pageSize = 500
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = parseOpsViewParam(c)
filter.Phase = strings.TrimSpace(c.Query("phase"))
filter.Owner = strings.TrimSpace(c.Query("error_owner"))
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
// Force request errors: client-visible status >= 400.
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
filter.Phase = ""
}
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes":
b := true
filter.Resolved = &b
case "0", "false", "no":
b := false
filter.Resolved = &b
default:
response.BadRequest(c, "Invalid resolved")
return
}
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
out := make([]int, 0, len(parts))
for _, part := range parts {
p := strings.TrimSpace(part)
if p == "" {
continue
}
n, err := strconv.Atoi(p)
if err != nil || n < 0 {
response.BadRequest(c, "Invalid status_codes")
return
}
out = append(out, n)
}
filter.StatusCodes = out
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// ListRequestErrors lists client-visible request errors.
// GET /api/v1/admin/ops/request-errors
func (h *OpsHandler) ListRequestErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 500 {
pageSize = 500
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = parseOpsViewParam(c)
filter.Phase = strings.TrimSpace(c.Query("phase"))
filter.Owner = strings.TrimSpace(c.Query("error_owner"))
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
// Force request errors: client-visible status >= 400.
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
filter.Phase = ""
}
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes":
b := true
filter.Resolved = &b
case "0", "false", "no":
b := false
filter.Resolved = &b
default:
response.BadRequest(c, "Invalid resolved")
return
}
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
out := make([]int, 0, len(parts))
for _, part := range parts {
p := strings.TrimSpace(part)
if p == "" {
continue
}
n, err := strconv.Atoi(p)
if err != nil || n < 0 {
response.BadRequest(c, "Invalid status_codes")
return
}
out = append(out, n)
}
filter.StatusCodes = out
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// GetRequestError returns request error detail.
// GET /api/v1/admin/ops/request-errors/:id
func (h *OpsHandler) GetRequestError(c *gin.Context) {
// same storage; just proxy to existing detail
h.GetErrorLogByID(c)
}
// ListRequestErrorUpstreamErrors lists upstream error logs correlated to a request error.
// GET /api/v1/admin/ops/request-errors/:id/upstream-errors
func (h *OpsHandler) ListRequestErrorUpstreamErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
// Load request error to get correlation keys.
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Correlate by request_id/client_request_id.
requestID := strings.TrimSpace(detail.RequestID)
clientRequestID := strings.TrimSpace(detail.ClientRequestID)
if requestID == "" && clientRequestID == "" {
response.Paginated(c, []*service.OpsErrorLog{}, 0, 1, 10)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 500 {
pageSize = 500
}
// Keep correlation window wide enough so linked upstream errors
// are discoverable even when UI defaults to 1h elsewhere.
startTime, endTime, err := parseOpsTimeRange(c, "30d")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = "all"
filter.Phase = "upstream"
filter.Owner = "provider"
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
// Prefer exact match on request_id; if missing, fall back to client_request_id.
if requestID != "" {
filter.RequestID = requestID
} else {
filter.ClientRequestID = clientRequestID
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
// If client asks for details, expand each upstream error log to include upstream response fields.
includeDetail := strings.TrimSpace(c.Query("include_detail"))
if includeDetail == "1" || strings.EqualFold(includeDetail, "true") || strings.EqualFold(includeDetail, "yes") {
details := make([]*service.OpsErrorLogDetail, 0, len(result.Errors))
for _, item := range result.Errors {
if item == nil {
continue
}
d, err := h.opsService.GetErrorLogByID(c.Request.Context(), item.ID)
if err != nil || d == nil {
continue
}
details = append(details, d)
}
response.Paginated(c, details, int64(result.Total), result.Page, result.PageSize)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// RetryRequestErrorClient retries the client request based on stored request body.
// POST /api/v1/admin/ops/request-errors/:id/retry-client
func (h *OpsHandler) RetryRequestErrorClient(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeClient, nil)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// RetryRequestErrorUpstreamEvent retries a specific upstream attempt using captured upstream_request_body.
// POST /api/v1/admin/ops/request-errors/:id/upstream-errors/:idx/retry
func (h *OpsHandler) RetryRequestErrorUpstreamEvent(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
idxStr := strings.TrimSpace(c.Param("idx"))
idx, err := strconv.Atoi(idxStr)
if err != nil || idx < 0 {
response.BadRequest(c, "Invalid upstream idx")
return
}
result, err := h.opsService.RetryUpstreamEvent(c.Request.Context(), subject.UserID, id, idx)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ResolveRequestError toggles resolved status.
// PUT /api/v1/admin/ops/request-errors/:id/resolve
func (h *OpsHandler) ResolveRequestError(c *gin.Context) {
h.UpdateErrorResolution(c)
}
// ListUpstreamErrors lists independent upstream errors.
// GET /api/v1/admin/ops/upstream-errors
func (h *OpsHandler) ListUpstreamErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 500 {
pageSize = 500
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = parseOpsViewParam(c)
filter.Phase = "upstream"
filter.Owner = "provider"
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes":
b := true
filter.Resolved = &b
case "0", "false", "no":
b := false
filter.Resolved = &b
default:
response.BadRequest(c, "Invalid resolved")
return
}
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
out := make([]int, 0, len(parts))
for _, part := range parts {
p := strings.TrimSpace(part)
if p == "" {
continue
}
n, err := strconv.Atoi(p)
if err != nil || n < 0 {
response.BadRequest(c, "Invalid status_codes")
return
}
out = append(out, n)
}
filter.StatusCodes = out
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// GetUpstreamError returns upstream error detail.
// GET /api/v1/admin/ops/upstream-errors/:id
func (h *OpsHandler) GetUpstreamError(c *gin.Context) {
h.GetErrorLogByID(c)
}
// RetryUpstreamError retries upstream error using the original account_id.
// POST /api/v1/admin/ops/upstream-errors/:id/retry
func (h *OpsHandler) RetryUpstreamError(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeUpstream, nil)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ResolveUpstreamError toggles resolved status.
// PUT /api/v1/admin/ops/upstream-errors/:id/resolve
func (h *OpsHandler) ResolveUpstreamError(c *gin.Context) {
h.UpdateErrorResolution(c)
}
// ==================== Existing endpoints ====================
// ListRequestDetails returns a request-level list (success + error) for drill-down.
// GET /api/v1/admin/ops/requests
func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 100 {
pageSize = 100
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsRequestDetailFilter{
Page: page,
PageSize: pageSize,
StartTime: &startTime,
EndTime: &endTime,
}
filter.Kind = strings.TrimSpace(c.Query("kind"))
filter.Platform = strings.TrimSpace(c.Query("platform"))
filter.Model = strings.TrimSpace(c.Query("model"))
filter.RequestID = strings.TrimSpace(c.Query("request_id"))
filter.Query = strings.TrimSpace(c.Query("q"))
filter.Sort = strings.TrimSpace(c.Query("sort"))
if v := strings.TrimSpace(c.Query("user_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
filter.UserID = &id
}
if v := strings.TrimSpace(c.Query("api_key_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid api_key_id")
return
}
filter.APIKeyID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if v := strings.TrimSpace(c.Query("min_duration_ms")); v != "" {
parsed, err := strconv.Atoi(v)
if err != nil || parsed < 0 {
response.BadRequest(c, "Invalid min_duration_ms")
return
}
filter.MinDurationMs = &parsed
}
if v := strings.TrimSpace(c.Query("max_duration_ms")); v != "" {
parsed, err := strconv.Atoi(v)
if err != nil || parsed < 0 {
response.BadRequest(c, "Invalid max_duration_ms")
return
}
filter.MaxDurationMs = &parsed
}
out, err := h.opsService.ListRequestDetails(c.Request.Context(), filter)
if err != nil {
// Invalid sort/kind/platform etc should be a bad request; keep it simple.
if strings.Contains(strings.ToLower(err.Error()), "invalid") {
response.BadRequest(c, err.Error())
return
}
response.Error(c, http.StatusInternalServerError, "Failed to list request details")
return
}
response.Paginated(c, out.Items, out.Total, out.Page, out.PageSize)
}
type opsRetryRequest struct {
Mode string `json:"mode"`
PinnedAccountID *int64 `json:"pinned_account_id"`
Force bool `json:"force"`
}
type opsResolveRequest struct {
Resolved bool `json:"resolved"`
}
// RetryErrorRequest retries a failed request using stored request_body.
// POST /api/v1/admin/ops/errors/:id/retry
func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
req := opsRetryRequest{Mode: service.OpsRetryModeClient}
if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if strings.TrimSpace(req.Mode) == "" {
req.Mode = service.OpsRetryModeClient
}
// Force flag is currently a UI-level acknowledgement. Server may still enforce safety constraints.
_ = req.Force
// Legacy endpoint safety: only allow retrying the client request here.
// Upstream retries must go through the split endpoints.
if strings.EqualFold(strings.TrimSpace(req.Mode), service.OpsRetryModeUpstream) {
response.BadRequest(c, "upstream retry is not supported on this endpoint")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ListRetryAttempts lists retry attempts for an error log.
// GET /api/v1/admin/ops/errors/:id/retries
func (h *OpsHandler) ListRetryAttempts(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
limit := 50
if v := strings.TrimSpace(c.Query("limit")); v != "" {
n, err := strconv.Atoi(v)
if err != nil || n <= 0 {
response.BadRequest(c, "Invalid limit")
return
}
limit = n
}
items, err := h.opsService.ListRetryAttemptsByErrorID(c.Request.Context(), id, limit)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, items)
}
// UpdateErrorResolution allows manual resolve/unresolve.
// PUT /api/v1/admin/ops/errors/:id/resolve
func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
var req opsResolveRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
uid := subject.UserID
if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid, nil); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"ok": true})
}
func parseOpsTimeRange(c *gin.Context, defaultRange string) (time.Time, time.Time, error) {
startStr := strings.TrimSpace(c.Query("start_time"))
endStr := strings.TrimSpace(c.Query("end_time"))
parseTS := func(s string) (time.Time, error) {
if s == "" {
return time.Time{}, nil
}
if t, err := time.Parse(time.RFC3339Nano, s); err == nil {
return t, nil
}
return time.Parse(time.RFC3339, s)
}
start, err := parseTS(startStr)
if err != nil {
return time.Time{}, time.Time{}, err
}
end, err := parseTS(endStr)
if err != nil {
return time.Time{}, time.Time{}, err
}
// start/end explicitly provided (even partially)
if startStr != "" || endStr != "" {
if end.IsZero() {
end = time.Now()
}
if start.IsZero() {
dur, _ := parseOpsDuration(defaultRange)
start = end.Add(-dur)
}
if start.After(end) {
return time.Time{}, time.Time{}, fmt.Errorf("invalid time range: start_time must be <= end_time")
}
if end.Sub(start) > 30*24*time.Hour {
return time.Time{}, time.Time{}, fmt.Errorf("invalid time range: max window is 30 days")
}
return start, end, nil
}
// time_range fallback
tr := strings.TrimSpace(c.Query("time_range"))
if tr == "" {
tr = defaultRange
}
dur, ok := parseOpsDuration(tr)
if !ok {
dur, _ = parseOpsDuration(defaultRange)
}
end = time.Now()
start = end.Add(-dur)
if end.Sub(start) > 30*24*time.Hour {
return time.Time{}, time.Time{}, fmt.Errorf("invalid time range: max window is 30 days")
}
return start, end, nil
}
func parseOpsDuration(v string) (time.Duration, bool) {
switch strings.TrimSpace(v) {
case "5m":
return 5 * time.Minute, true
case "30m":
return 30 * time.Minute, true
case "1h":
return time.Hour, true
case "6h":
return 6 * time.Hour, true
case "24h":
return 24 * time.Hour, true
case "7d":
return 7 * 24 * time.Hour, true
case "30d":
return 30 * 24 * time.Hour, true
default:
return 0, false
}
}

View File

@@ -1,213 +0,0 @@
package admin
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GetConcurrencyStats returns real-time concurrency usage aggregated by platform/group/account.
// GET /api/v1/admin/ops/concurrency
func (h *OpsHandler) GetConcurrencyStats(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
response.Success(c, gin.H{
"enabled": false,
"platform": map[string]*service.PlatformConcurrencyInfo{},
"group": map[int64]*service.GroupConcurrencyInfo{},
"account": map[int64]*service.AccountConcurrencyInfo{},
"timestamp": time.Now().UTC(),
})
return
}
platformFilter := strings.TrimSpace(c.Query("platform"))
var groupID *int64
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
groupID = &id
}
platform, group, account, collectedAt, err := h.opsService.GetConcurrencyStats(c.Request.Context(), platformFilter, groupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
payload := gin.H{
"enabled": true,
"platform": platform,
"group": group,
"account": account,
}
if collectedAt != nil {
payload["timestamp"] = collectedAt.UTC()
}
response.Success(c, payload)
}
// GetAccountAvailability returns account availability statistics.
// GET /api/v1/admin/ops/account-availability
//
// Query params:
// - platform: optional
// - group_id: optional
func (h *OpsHandler) GetAccountAvailability(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
response.Success(c, gin.H{
"enabled": false,
"platform": map[string]*service.PlatformAvailability{},
"group": map[int64]*service.GroupAvailability{},
"account": map[int64]*service.AccountAvailability{},
"timestamp": time.Now().UTC(),
})
return
}
platform := strings.TrimSpace(c.Query("platform"))
var groupID *int64
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
groupID = &id
}
platformStats, groupStats, accountStats, collectedAt, err := h.opsService.GetAccountAvailabilityStats(c.Request.Context(), platform, groupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
payload := gin.H{
"enabled": true,
"platform": platformStats,
"group": groupStats,
"account": accountStats,
}
if collectedAt != nil {
payload["timestamp"] = collectedAt.UTC()
}
response.Success(c, payload)
}
func parseOpsRealtimeWindow(v string) (time.Duration, string, bool) {
switch strings.ToLower(strings.TrimSpace(v)) {
case "", "1min", "1m":
return 1 * time.Minute, "1min", true
case "5min", "5m":
return 5 * time.Minute, "5min", true
case "30min", "30m":
return 30 * time.Minute, "30min", true
case "1h", "60m", "60min":
return 1 * time.Hour, "1h", true
default:
return 0, "", false
}
}
// GetRealtimeTrafficSummary returns QPS/TPS current/peak/avg for the selected window.
// GET /api/v1/admin/ops/realtime-traffic
//
// Query params:
// - window: 1min|5min|30min|1h (default: 1min)
// - platform: optional
// - group_id: optional
func (h *OpsHandler) GetRealtimeTrafficSummary(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
windowDur, windowLabel, ok := parseOpsRealtimeWindow(c.Query("window"))
if !ok {
response.BadRequest(c, "Invalid window")
return
}
platform := strings.TrimSpace(c.Query("platform"))
var groupID *int64
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
groupID = &id
}
endTime := time.Now().UTC()
startTime := endTime.Add(-windowDur)
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
disabledSummary := &service.OpsRealtimeTrafficSummary{
Window: windowLabel,
StartTime: startTime,
EndTime: endTime,
Platform: platform,
GroupID: groupID,
QPS: service.OpsRateSummary{},
TPS: service.OpsRateSummary{},
}
response.Success(c, gin.H{
"enabled": false,
"summary": disabledSummary,
"timestamp": endTime,
})
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: platform,
GroupID: groupID,
QueryMode: service.OpsQueryModeRaw,
}
summary, err := h.opsService.GetRealtimeTrafficSummary(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
if summary != nil {
summary.Window = windowLabel
}
response.Success(c, gin.H{
"enabled": true,
"summary": summary,
"timestamp": endTime,
})
}

View File

@@ -1,194 +0,0 @@
package admin
import (
"net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GetEmailNotificationConfig returns Ops email notification config (DB-backed).
// GET /api/v1/admin/ops/email-notification/config
func (h *OpsHandler) GetEmailNotificationConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetEmailNotificationConfig(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get email notification config")
return
}
response.Success(c, cfg)
}
// UpdateEmailNotificationConfig updates Ops email notification config (DB-backed).
// PUT /api/v1/admin/ops/email-notification/config
func (h *OpsHandler) UpdateEmailNotificationConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsEmailNotificationConfigUpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
updated, err := h.opsService.UpdateEmailNotificationConfig(c.Request.Context(), &req)
if err != nil {
// Most failures here are validation errors from request payload; treat as 400.
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetAlertRuntimeSettings returns Ops alert evaluator runtime settings (DB-backed).
// GET /api/v1/admin/ops/runtime/alert
func (h *OpsHandler) GetAlertRuntimeSettings(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetOpsAlertRuntimeSettings(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get alert runtime settings")
return
}
response.Success(c, cfg)
}
// UpdateAlertRuntimeSettings updates Ops alert evaluator runtime settings (DB-backed).
// PUT /api/v1/admin/ops/runtime/alert
func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsAlertRuntimeSettings
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
updated, err := h.opsService.UpdateOpsAlertRuntimeSettings(c.Request.Context(), &req)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetAdvancedSettings returns Ops advanced settings (DB-backed).
// GET /api/v1/admin/ops/advanced-settings
func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetOpsAdvancedSettings(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get advanced settings")
return
}
response.Success(c, cfg)
}
// UpdateAdvancedSettings updates Ops advanced settings (DB-backed).
// PUT /api/v1/admin/ops/advanced-settings
func (h *OpsHandler) UpdateAdvancedSettings(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsAdvancedSettings
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
updated, err := h.opsService.UpdateOpsAdvancedSettings(c.Request.Context(), &req)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetMetricThresholds returns Ops metric thresholds (DB-backed).
// GET /api/v1/admin/ops/settings/metric-thresholds
func (h *OpsHandler) GetMetricThresholds(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetMetricThresholds(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get metric thresholds")
return
}
response.Success(c, cfg)
}
// UpdateMetricThresholds updates Ops metric thresholds (DB-backed).
// PUT /api/v1/admin/ops/settings/metric-thresholds
func (h *OpsHandler) UpdateMetricThresholds(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsMetricThresholds
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
updated, err := h.opsService.UpdateMetricThresholds(c.Request.Context(), &req)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}

View File

@@ -1,771 +0,0 @@
package admin
import (
"context"
"encoding/json"
"log"
"math"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
type OpsWSProxyConfig struct {
TrustProxy bool
TrustedProxies []netip.Prefix
OriginPolicy string
}
const (
envOpsWSTrustProxy = "OPS_WS_TRUST_PROXY"
envOpsWSTrustedProxies = "OPS_WS_TRUSTED_PROXIES"
envOpsWSOriginPolicy = "OPS_WS_ORIGIN_POLICY"
envOpsWSMaxConns = "OPS_WS_MAX_CONNS"
envOpsWSMaxConnsPerIP = "OPS_WS_MAX_CONNS_PER_IP"
)
const (
OriginPolicyStrict = "strict"
OriginPolicyPermissive = "permissive"
)
var opsWSProxyConfig = loadOpsWSProxyConfigFromEnv()
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return isAllowedOpsWSOrigin(r)
},
// Subprotocol negotiation:
// - The frontend passes ["sub2api-admin", "jwt.<token>"].
// - We always select "sub2api-admin" so the token is never echoed back in the handshake response.
Subprotocols: []string{"sub2api-admin"},
}
const (
qpsWSPushInterval = 2 * time.Second
qpsWSRefreshInterval = 5 * time.Second
qpsWSRequestCountWindow = 1 * time.Minute
defaultMaxWSConns = 100
defaultMaxWSConnsPerIP = 20
)
var wsConnCount atomic.Int32
var wsConnCountByIP sync.Map // map[string]*atomic.Int32
const qpsWSIdleStopDelay = 30 * time.Second
const (
opsWSCloseRealtimeDisabled = 4001
)
var qpsWSIdleStopMu sync.Mutex
var qpsWSIdleStopTimer *time.Timer
func cancelQPSWSIdleStop() {
qpsWSIdleStopMu.Lock()
if qpsWSIdleStopTimer != nil {
qpsWSIdleStopTimer.Stop()
qpsWSIdleStopTimer = nil
}
qpsWSIdleStopMu.Unlock()
}
func scheduleQPSWSIdleStop() {
qpsWSIdleStopMu.Lock()
if qpsWSIdleStopTimer != nil {
qpsWSIdleStopMu.Unlock()
return
}
qpsWSIdleStopTimer = time.AfterFunc(qpsWSIdleStopDelay, func() {
// Only stop if truly idle at fire time.
if wsConnCount.Load() == 0 {
qpsWSCache.Stop()
}
qpsWSIdleStopMu.Lock()
qpsWSIdleStopTimer = nil
qpsWSIdleStopMu.Unlock()
})
qpsWSIdleStopMu.Unlock()
}
type opsWSRuntimeLimits struct {
MaxConns int32
MaxConnsPerIP int32
}
var opsWSLimits = loadOpsWSRuntimeLimitsFromEnv()
const (
qpsWSWriteTimeout = 10 * time.Second
qpsWSPongWait = 60 * time.Second
qpsWSPingInterval = 30 * time.Second
// We don't expect clients to send application messages; we only read to process control frames (Pong/Close).
qpsWSMaxReadBytes = 1024
)
type opsWSQPSCache struct {
refreshInterval time.Duration
requestCountWindow time.Duration
lastUpdatedUnixNano atomic.Int64
payload atomic.Value // []byte
opsService *service.OpsService
cancel context.CancelFunc
done chan struct{}
mu sync.Mutex
running bool
}
var qpsWSCache = &opsWSQPSCache{
refreshInterval: qpsWSRefreshInterval,
requestCountWindow: qpsWSRequestCountWindow,
}
func (c *opsWSQPSCache) start(opsService *service.OpsService) {
if c == nil || opsService == nil {
return
}
for {
c.mu.Lock()
if c.running {
c.mu.Unlock()
return
}
// If a previous refresh loop is currently stopping, wait for it to fully exit.
done := c.done
if done != nil {
c.mu.Unlock()
<-done
c.mu.Lock()
if c.done == done && !c.running {
c.done = nil
}
c.mu.Unlock()
continue
}
c.opsService = opsService
ctx, cancel := context.WithCancel(context.Background())
c.cancel = cancel
c.done = make(chan struct{})
done = c.done
c.running = true
c.mu.Unlock()
go func() {
defer close(done)
c.refreshLoop(ctx)
}()
return
}
}
// Stop stops the background refresh loop.
// It is safe to call multiple times.
func (c *opsWSQPSCache) Stop() {
if c == nil {
return
}
c.mu.Lock()
if !c.running {
done := c.done
c.mu.Unlock()
if done != nil {
<-done
}
return
}
cancel := c.cancel
c.cancel = nil
c.running = false
c.opsService = nil
done := c.done
c.mu.Unlock()
if cancel != nil {
cancel()
}
if done != nil {
<-done
}
c.mu.Lock()
if c.done == done && !c.running {
c.done = nil
}
c.mu.Unlock()
}
func (c *opsWSQPSCache) refreshLoop(ctx context.Context) {
ticker := time.NewTicker(c.refreshInterval)
defer ticker.Stop()
c.refresh(ctx)
for {
select {
case <-ticker.C:
c.refresh(ctx)
case <-ctx.Done():
return
}
}
}
func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
if c == nil {
return
}
c.mu.Lock()
opsService := c.opsService
c.mu.Unlock()
if opsService == nil {
return
}
if parentCtx == nil {
parentCtx = context.Background()
}
ctx, cancel := context.WithTimeout(parentCtx, 10*time.Second)
defer cancel()
now := time.Now().UTC()
stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now)
if err != nil || stats == nil {
if err != nil {
log.Printf("[OpsWS] refresh: get window stats failed: %v", err)
}
return
}
requestCount := stats.SuccessCount + stats.ErrorCountTotal
qps := 0.0
tps := 0.0
if c.requestCountWindow > 0 {
seconds := c.requestCountWindow.Seconds()
qps = roundTo1DP(float64(requestCount) / seconds)
tps = roundTo1DP(float64(stats.TokenConsumed) / seconds)
}
payload := gin.H{
"type": "qps_update",
"timestamp": now.Format(time.RFC3339),
"data": gin.H{
"qps": qps,
"tps": tps,
"request_count": requestCount,
},
}
msg, err := json.Marshal(payload)
if err != nil {
log.Printf("[OpsWS] refresh: marshal payload failed: %v", err)
return
}
c.payload.Store(msg)
c.lastUpdatedUnixNano.Store(now.UnixNano())
}
func roundTo1DP(v float64) float64 {
return math.Round(v*10) / 10
}
func (c *opsWSQPSCache) getPayload() []byte {
if c == nil {
return nil
}
if cached, ok := c.payload.Load().([]byte); ok && cached != nil {
return cached
}
return nil
}
func closeWS(conn *websocket.Conn, code int, reason string) {
if conn == nil {
return
}
msg := websocket.FormatCloseMessage(code, reason)
_ = conn.WriteControl(websocket.CloseMessage, msg, time.Now().Add(qpsWSWriteTimeout))
_ = conn.Close()
}
// QPSWSHandler handles realtime QPS push via WebSocket.
// GET /api/v1/admin/ops/ws/qps
func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
clientIP := requestClientIP(c.Request)
if h == nil || h.opsService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "ops service not initialized"})
return
}
// If realtime monitoring is disabled, prefer a successful WS upgrade followed by a clean close
// with a deterministic close code. This prevents clients from spinning on 404/1006 reconnect loops.
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "ops realtime monitoring is disabled"})
return
}
closeWS(conn, opsWSCloseRealtimeDisabled, "realtime_disabled")
return
}
cancelQPSWSIdleStop()
// Lazily start the background refresh loop so unit tests that never hit the
// websocket route don't spawn goroutines that depend on DB/Redis stubs.
qpsWSCache.start(h.opsService)
// Reserve a global slot before upgrading the connection to keep the limit strict.
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
defer func() {
if wsConnCount.Add(-1) == 0 {
scheduleQPSWSIdleStop()
}
}()
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
defer releaseOpsWSIPSlot(clientIP)
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("[OpsWS] upgrade failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
handleQPSWebSocket(c.Request.Context(), conn)
}
func tryAcquireOpsWSTotalSlot(limit int32) bool {
if limit <= 0 {
return true
}
for {
current := wsConnCount.Load()
if current >= limit {
return false
}
if wsConnCount.CompareAndSwap(current, current+1) {
return true
}
}
}
func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool {
if strings.TrimSpace(clientIP) == "" || limit <= 0 {
return true
}
v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{})
counter, ok := v.(*atomic.Int32)
if !ok {
return false
}
for {
current := counter.Load()
if current >= limit {
return false
}
if counter.CompareAndSwap(current, current+1) {
return true
}
}
}
func releaseOpsWSIPSlot(clientIP string) {
if strings.TrimSpace(clientIP) == "" {
return
}
v, ok := wsConnCountByIP.Load(clientIP)
if !ok {
return
}
counter, ok := v.(*atomic.Int32)
if !ok {
return
}
next := counter.Add(-1)
if next <= 0 {
// Best-effort cleanup; safe even if a new slot was acquired concurrently.
wsConnCountByIP.Delete(clientIP)
}
}
func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
if conn == nil {
return
}
ctx, cancel := context.WithCancel(parentCtx)
defer cancel()
var closeOnce sync.Once
closeConn := func() {
closeOnce.Do(func() {
_ = conn.Close()
})
}
closeFrameCh := make(chan []byte, 1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
conn.SetReadLimit(qpsWSMaxReadBytes)
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
log.Printf("[OpsWS] set read deadline failed: %v", err)
return
}
conn.SetPongHandler(func(string) error {
return conn.SetReadDeadline(time.Now().Add(qpsWSPongWait))
})
conn.SetCloseHandler(func(code int, text string) error {
select {
case closeFrameCh <- websocket.FormatCloseMessage(code, text):
default:
}
cancel()
return nil
})
for {
_, _, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Printf("[OpsWS] read failed: %v", err)
}
return
}
}
}()
// Push QPS data every 2 seconds (values are globally cached and refreshed at most once per qpsWSRefreshInterval).
pushTicker := time.NewTicker(qpsWSPushInterval)
defer pushTicker.Stop()
// Heartbeat ping every 30 seconds.
pingTicker := time.NewTicker(qpsWSPingInterval)
defer pingTicker.Stop()
writeWithTimeout := func(messageType int, data []byte) error {
if err := conn.SetWriteDeadline(time.Now().Add(qpsWSWriteTimeout)); err != nil {
return err
}
return conn.WriteMessage(messageType, data)
}
sendClose := func(closeFrame []byte) {
if closeFrame == nil {
closeFrame = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
}
_ = writeWithTimeout(websocket.CloseMessage, closeFrame)
}
for {
select {
case <-pushTicker.C:
msg := qpsWSCache.getPayload()
if msg == nil {
continue
}
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
log.Printf("[OpsWS] write failed: %v", err)
cancel()
closeConn()
wg.Wait()
return
}
case <-pingTicker.C:
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
log.Printf("[OpsWS] ping failed: %v", err)
cancel()
closeConn()
wg.Wait()
return
}
case closeFrame := <-closeFrameCh:
sendClose(closeFrame)
closeConn()
wg.Wait()
return
case <-ctx.Done():
var closeFrame []byte
select {
case closeFrame = <-closeFrameCh:
default:
}
sendClose(closeFrame)
closeConn()
wg.Wait()
return
}
}
}
func isAllowedOpsWSOrigin(r *http.Request) bool {
if r == nil {
return false
}
origin := strings.TrimSpace(r.Header.Get("Origin"))
if origin == "" {
switch strings.ToLower(strings.TrimSpace(opsWSProxyConfig.OriginPolicy)) {
case OriginPolicyStrict:
return false
case OriginPolicyPermissive, "":
return true
default:
return true
}
}
parsed, err := url.Parse(origin)
if err != nil || parsed.Hostname() == "" {
return false
}
originHost := strings.ToLower(parsed.Hostname())
trustProxyHeaders := shouldTrustOpsWSProxyHeaders(r)
reqHost := hostWithoutPort(r.Host)
if trustProxyHeaders {
xfHost := strings.TrimSpace(r.Header.Get("X-Forwarded-Host"))
if xfHost != "" {
xfHost = strings.TrimSpace(strings.Split(xfHost, ",")[0])
if xfHost != "" {
reqHost = hostWithoutPort(xfHost)
}
}
}
reqHost = strings.ToLower(reqHost)
if reqHost == "" {
return false
}
return originHost == reqHost
}
func shouldTrustOpsWSProxyHeaders(r *http.Request) bool {
if r == nil {
return false
}
if !opsWSProxyConfig.TrustProxy {
return false
}
peerIP, ok := requestPeerIP(r)
if !ok {
return false
}
return isAddrInTrustedProxies(peerIP, opsWSProxyConfig.TrustedProxies)
}
func requestPeerIP(r *http.Request) (netip.Addr, bool) {
if r == nil {
return netip.Addr{}, false
}
host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
if err != nil {
host = strings.TrimSpace(r.RemoteAddr)
}
host = strings.TrimPrefix(host, "[")
host = strings.TrimSuffix(host, "]")
if host == "" {
return netip.Addr{}, false
}
addr, err := netip.ParseAddr(host)
if err != nil {
return netip.Addr{}, false
}
return addr.Unmap(), true
}
func requestClientIP(r *http.Request) string {
if r == nil {
return ""
}
trustProxyHeaders := shouldTrustOpsWSProxyHeaders(r)
if trustProxyHeaders {
xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For"))
if xff != "" {
// Use the left-most entry (original client). If multiple proxies add values, they are comma-separated.
xff = strings.TrimSpace(strings.Split(xff, ",")[0])
xff = strings.TrimPrefix(xff, "[")
xff = strings.TrimSuffix(xff, "]")
if addr, err := netip.ParseAddr(xff); err == nil && addr.IsValid() {
return addr.Unmap().String()
}
}
}
if peer, ok := requestPeerIP(r); ok && peer.IsValid() {
return peer.String()
}
return ""
}
func isAddrInTrustedProxies(addr netip.Addr, trusted []netip.Prefix) bool {
if !addr.IsValid() {
return false
}
for _, p := range trusted {
if p.Contains(addr) {
return true
}
}
return false
}
func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
cfg := OpsWSProxyConfig{
TrustProxy: true,
TrustedProxies: defaultTrustedProxies(),
OriginPolicy: OriginPolicyPermissive,
}
if v := strings.TrimSpace(os.Getenv(envOpsWSTrustProxy)); v != "" {
if parsed, err := strconv.ParseBool(v); err == nil {
cfg.TrustProxy = parsed
} else {
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
}
}
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
prefixes, invalid := parseTrustedProxyList(raw)
if len(invalid) > 0 {
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
}
cfg.TrustedProxies = prefixes
}
if v := strings.TrimSpace(os.Getenv(envOpsWSOriginPolicy)); v != "" {
normalized := strings.ToLower(v)
switch normalized {
case OriginPolicyStrict, OriginPolicyPermissive:
cfg.OriginPolicy = normalized
default:
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
}
}
return cfg
}
func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits {
cfg := opsWSRuntimeLimits{
MaxConns: defaultMaxWSConns,
MaxConnsPerIP: defaultMaxWSConnsPerIP,
}
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConns)); v != "" {
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
cfg.MaxConns = int32(parsed)
} else {
log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
}
}
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" {
if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 {
cfg.MaxConnsPerIP = int32(parsed)
} else {
log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
}
}
return cfg
}
func defaultTrustedProxies() []netip.Prefix {
prefixes, _ := parseTrustedProxyList("127.0.0.0/8,::1/128")
return prefixes
}
func parseTrustedProxyList(raw string) (prefixes []netip.Prefix, invalid []string) {
for _, token := range strings.Split(raw, ",") {
item := strings.TrimSpace(token)
if item == "" {
continue
}
var (
p netip.Prefix
err error
)
if strings.Contains(item, "/") {
p, err = netip.ParsePrefix(item)
} else {
var addr netip.Addr
addr, err = netip.ParseAddr(item)
if err == nil {
addr = addr.Unmap()
bits := 128
if addr.Is4() {
bits = 32
}
p = netip.PrefixFrom(addr, bits)
}
}
if err != nil || !p.IsValid() {
invalid = append(invalid, item)
continue
}
prefixes = append(prefixes, p.Masked())
}
return prefixes, invalid
}
func hostWithoutPort(hostport string) string {
hostport = strings.TrimSpace(hostport)
if hostport == "" {
return ""
}
if host, _, err := net.SplitHostPort(hostport); err == nil {
return host
}
if strings.HasPrefix(hostport, "[") && strings.HasSuffix(hostport, "]") {
return strings.Trim(hostport, "[]")
}
parts := strings.Split(hostport, ":")
return parts[0]
}

View File

@@ -196,28 +196,6 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
response.Success(c, gin.H{"message": "Proxy deleted successfully"})
}
// BatchDelete handles batch deleting proxies
// POST /api/v1/admin/proxies/batch-delete
func (h *ProxyHandler) BatchDelete(c *gin.Context) {
type BatchDeleteRequest struct {
IDs []int64 `json:"ids" binding:"required,min=1"`
}
var req BatchDeleteRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.adminService.BatchDeleteProxies(c.Request.Context(), req.IDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// Test handles testing proxy connectivity
// POST /api/v1/admin/proxies/:id/test
func (h *ProxyHandler) Test(c *gin.Context) {
@@ -265,17 +243,19 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
return
}
accounts, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID)
page, pageSize := response.ParsePagination(c)
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.ProxyAccountSummary, 0, len(accounts))
out := make([]dto.Account, 0, len(accounts))
for i := range accounts {
out = append(out, *dto.ProxyAccountSummaryFromService(&accounts[i]))
out = append(out, *dto.AccountFromService(&accounts[i]))
}
response.Success(c, out)
response.Paginated(c, out, total, page, pageSize)
}
// BatchCreateProxyItem represents a single proxy in batch create request

View File

@@ -19,16 +19,14 @@ type SettingHandler struct {
settingService *service.SettingService
emailService *service.EmailService
turnstileService *service.TurnstileService
opsService *service.OpsService
}
// NewSettingHandler 创建系统设置处理器
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
opsService: opsService,
}
}
@@ -41,9 +39,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
return
}
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
@@ -77,10 +72,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
FallbackModelAntigravity: settings.FallbackModelAntigravity,
EnableIdentityPatch: settings.EnableIdentityPatch,
IdentityPatchPrompt: settings.IdentityPatchPrompt,
OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: settings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
})
}
@@ -104,7 +95,7 @@ type UpdateSettingsRequest struct {
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key"`
// LinuxDo Connect OAuth 登录
// LinuxDo Connect OAuth 登录(终端用户 SSO
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
@@ -133,12 +124,6 @@ type UpdateSettingsRequest struct {
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
// Ops monitoring (vNext)
OpsMonitoringEnabled *bool `json:"ops_monitoring_enabled"`
OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault *string `json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"`
}
// UpdateSettings 更新系统设置
@@ -223,18 +208,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds
if v < 60 {
v = 60
}
if v > 3600 {
v = 3600
}
req.OpsMetricsIntervalSeconds = &v
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
@@ -268,30 +241,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
}
return previousSettings.OpsMonitoringEnabled
}(),
OpsRealtimeMonitoringEnabled: func() bool {
if req.OpsRealtimeMonitoringEnabled != nil {
return *req.OpsRealtimeMonitoringEnabled
}
return previousSettings.OpsRealtimeMonitoringEnabled
}(),
OpsQueryModeDefault: func() string {
if req.OpsQueryModeDefault != nil {
return *req.OpsQueryModeDefault
}
return previousSettings.OpsQueryModeDefault
}(),
OpsMetricsIntervalSeconds: func() int {
if req.OpsMetricsIntervalSeconds != nil {
return *req.OpsMetricsIntervalSeconds
}
return previousSettings.OpsMetricsIntervalSeconds
}(),
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
@@ -341,10 +290,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
})
}
@@ -466,18 +411,6 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
changed = append(changed, "identity_patch_prompt")
}
if before.OpsMonitoringEnabled != after.OpsMonitoringEnabled {
changed = append(changed, "ops_monitoring_enabled")
}
if before.OpsRealtimeMonitoringEnabled != after.OpsRealtimeMonitoringEnabled {
changed = append(changed, "ops_realtime_monitoring_enabled")
}
if before.OpsQueryModeDefault != after.OpsQueryModeDefault {
changed = append(changed, "ops_query_mode_default")
}
if before.OpsMetricsIntervalSeconds != after.OpsMetricsIntervalSeconds {
changed = append(changed, "ops_metrics_interval_seconds")
}
return changed
}
@@ -654,68 +587,3 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
response.Success(c, gin.H{"message": "Admin API key deleted"})
}
// GetStreamTimeoutSettings 获取流超时处理配置
// GET /api/v1/admin/settings/stream-timeout
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
settings, err := h.settingService.GetStreamTimeoutSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.StreamTimeoutSettings{
Enabled: settings.Enabled,
Action: settings.Action,
TempUnschedMinutes: settings.TempUnschedMinutes,
ThresholdCount: settings.ThresholdCount,
ThresholdWindowMinutes: settings.ThresholdWindowMinutes,
})
}
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
type UpdateStreamTimeoutSettingsRequest struct {
Enabled bool `json:"enabled"`
Action string `json:"action"`
TempUnschedMinutes int `json:"temp_unsched_minutes"`
ThresholdCount int `json:"threshold_count"`
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
}
// UpdateStreamTimeoutSettings 更新流超时处理配置
// PUT /api/v1/admin/settings/stream-timeout
func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
var req UpdateStreamTimeoutSettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
settings := &service.StreamTimeoutSettings{
Enabled: req.Enabled,
Action: req.Action,
TempUnschedMinutes: req.TempUnschedMinutes,
ThresholdCount: req.ThresholdCount,
ThresholdWindowMinutes: req.ThresholdWindowMinutes,
}
if err := h.settingService.SetStreamTimeoutSettings(c.Request.Context(), settings); err != nil {
response.BadRequest(c, err.Error())
return
}
// 重新获取设置返回
updatedSettings, err := h.settingService.GetStreamTimeoutSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.StreamTimeoutSettings{
Enabled: updatedSettings.Enabled,
Action: updatedSettings.Action,
TempUnschedMinutes: updatedSettings.TempUnschedMinutes,
ThresholdCount: updatedSettings.ThresholdCount,
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
})
}

View File

@@ -1,377 +0,0 @@
package admin
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type cleanupRepoStub struct {
mu sync.Mutex
created []*service.UsageCleanupTask
listTasks []service.UsageCleanupTask
listResult *pagination.PaginationResult
listErr error
statusByID map[int64]string
}
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
if task == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
if task.ID == 0 {
task.ID = int64(len(s.created) + 1)
}
if task.CreatedAt.IsZero() {
task.CreatedAt = time.Now().UTC()
}
task.UpdatedAt = task.CreatedAt
clone := *task
s.created = append(s.created, &clone)
return nil
}
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.listTasks, s.listResult, s.listErr
}
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
return nil, nil
}
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusByID == nil {
return "", sql.ErrNoRows
}
status, ok := s.statusByID[taskID]
if !ok {
return "", sql.ErrNoRows
}
return status, nil
}
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
return nil
}
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
status := s.statusByID[taskID]
if status != service.UsageCleanupStatusPending && status != service.UsageCleanupStatusRunning {
return false, nil
}
s.statusByID[taskID] = service.UsageCleanupStatusCanceled
return true, nil
}
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
return nil
}
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
return nil
}
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
return 0, nil
}
var _ service.UsageCleanupRepository = (*cleanupRepoStub)(nil)
func setupCleanupRouter(cleanupService *service.UsageCleanupService, userID int64) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
if userID > 0 {
router.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
c.Next()
})
}
handler := NewUsageHandler(nil, nil, nil, cleanupService)
router.POST("/api/v1/admin/usage/cleanup-tasks", handler.CreateCleanupTask)
router.GET("/api/v1/admin/usage/cleanup-tasks", handler.ListCleanupTasks)
router.POST("/api/v1/admin/usage/cleanup-tasks/:id/cancel", handler.CancelCleanupTask)
return router
}
func TestUsageHandlerCreateCleanupTaskUnauthorized(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 0)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusUnauthorized, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskUnavailable(t *testing.T) {
router := setupCleanupRouter(nil, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskBindError(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString("{bad-json"))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskMissingRange(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidDate(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-13-01",
"end_date": "2024-01-02",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-02-40",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": " 2024-01-01 ",
"end_date": "2024-01-02",
"timezone": "UTC",
"model": "gpt-4",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.Equal(t, int64(99), created.CreatedBy)
require.NotNil(t, created.Filters.Model)
require.Equal(t, "gpt-4", *created.Filters.Model)
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC).Add(24*time.Hour - time.Nanosecond)
require.True(t, created.Filters.StartTime.Equal(start))
require.True(t, created.Filters.EndTime.Equal(end))
}
func TestUsageHandlerListCleanupTasksUnavailable(t *testing.T) {
router := setupCleanupRouter(nil, 0)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}
func TestUsageHandlerListCleanupTasksSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
repo.listTasks = []service.UsageCleanupTask{
{
ID: 7,
Status: service.UsageCleanupStatusSucceeded,
CreatedBy: 4,
},
}
repo.listResult = &pagination.PaginationResult{Total: 1, Page: 1, PageSize: 20, Pages: 1}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Items []dto.UsageCleanupTask `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Items, 1)
require.Equal(t, int64(7), resp.Data.Items[0].ID)
require.Equal(t, int64(1), resp.Data.Total)
require.Equal(t, 1, resp.Data.Page)
}
func TestUsageHandlerListCleanupTasksError(t *testing.T) {
repo := &cleanupRepoStub{listErr: errors.New("boom")}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusInternalServerError, recorder.Code)
}
func TestUsageHandlerCancelCleanupTaskUnauthorized(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 0)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/1/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskNotFound(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/999/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskConflict(t *testing.T) {
repo := &cleanupRepoStub{statusByID: map[int64]string{2: service.UsageCleanupStatusSucceeded}}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/2/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{statusByID: map[int64]string{3: service.UsageCleanupStatusPending}}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/3/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}

View File

@@ -1,10 +1,7 @@
package admin
import (
"log"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -12,7 +9,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -20,10 +16,9 @@ import (
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageService *service.UsageService
apiKeyService *service.APIKeyService
adminService service.AdminService
cleanupService *service.UsageCleanupService
usageService *service.UsageService
apiKeyService *service.APIKeyService
adminService service.AdminService
}
// NewUsageHandler creates a new admin usage handler
@@ -31,30 +26,14 @@ func NewUsageHandler(
usageService *service.UsageService,
apiKeyService *service.APIKeyService,
adminService service.AdminService,
cleanupService *service.UsageCleanupService,
) *UsageHandler {
return &UsageHandler{
usageService: usageService,
apiKeyService: apiKeyService,
adminService: adminService,
cleanupService: cleanupService,
usageService: usageService,
apiKeyService: apiKeyService,
adminService: adminService,
}
}
// CreateUsageCleanupTaskRequest represents cleanup task creation request
type CreateUsageCleanupTaskRequest struct {
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
UserID *int64 `json:"user_id"`
APIKeyID *int64 `json:"api_key_id"`
AccountID *int64 `json:"account_id"`
GroupID *int64 `json:"group_id"`
Model *string `json:"model"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
Timezone string `json:"timezone"`
}
// List handles listing all usage records with filters
// GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) {
@@ -365,162 +344,3 @@ func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
response.Success(c, result)
}
// ListCleanupTasks handles listing usage cleanup tasks
// GET /api/v1/admin/usage/cleanup-tasks
func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
operator := int64(0)
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
operator = subject.UserID
}
page, pageSize := response.ParsePagination(c)
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
if err != nil {
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
response.ErrorFrom(c, err)
return
}
out := make([]dto.UsageCleanupTask, 0, len(tasks))
for i := range tasks {
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
}
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
response.Paginated(c, out, result.Total, page, pageSize)
}
// CreateCleanupTask handles creating a usage cleanup task
// POST /api/v1/admin/usage/cleanup-tasks
func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Unauthorized(c, "Unauthorized")
return
}
var req CreateUsageCleanupTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.StartDate = strings.TrimSpace(req.StartDate)
req.EndDate = strings.TrimSpace(req.EndDate)
if req.StartDate == "" || req.EndDate == "" {
response.BadRequest(c, "start_date and end_date are required")
return
}
startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
filters := service.UsageCleanupFilters{
StartTime: startTime,
EndTime: endTime,
UserID: req.UserID,
APIKeyID: req.APIKeyID,
AccountID: req.AccountID,
GroupID: req.GroupID,
Model: req.Model,
Stream: req.Stream,
BillingType: req.BillingType,
}
var userID any
if filters.UserID != nil {
userID = *filters.UserID
}
var apiKeyID any
if filters.APIKeyID != nil {
apiKeyID = *filters.APIKeyID
}
var accountID any
if filters.AccountID != nil {
accountID = *filters.AccountID
}
var groupID any
if filters.GroupID != nil {
groupID = *filters.GroupID
}
var model any
if filters.Model != nil {
model = *filters.Model
}
var stream any
if filters.Stream != nil {
stream = *filters.Stream
}
var billingType any
if filters.BillingType != nil {
billingType = *filters.BillingType
}
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
userID,
apiKeyID,
accountID,
groupID,
model,
stream,
billingType,
req.Timezone,
)
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
if err != nil {
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
response.Success(c, dto.UsageCleanupTaskFromService(task))
}
// CancelCleanupTask handles canceling a usage cleanup task
// POST /api/v1/admin/usage/cleanup-tasks/:id/cancel
func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Unauthorized(c, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
taskID, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || taskID <= 0 {
response.BadRequest(c, "Invalid task id")
return
}
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
}

View File

@@ -3,7 +3,6 @@ package handler
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -77,7 +76,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if req.VerifyCode == "" {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.ErrorFrom(c, err)
return
}
@@ -106,7 +105,7 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.ErrorFrom(c, err)
return
}
@@ -133,7 +132,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.ErrorFrom(c, err)
return
}

View File

@@ -73,27 +73,25 @@ func GroupFromServiceShallow(g *service.Group) *Group {
return nil
}
return &Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
}
}
@@ -116,7 +114,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
}
out := &Account{
return &Account{
ID: a.ID,
Name: a.Name,
Notes: a.Notes,
@@ -127,7 +125,6 @@ func AccountFromServiceShallow(a *service.Account) *Account {
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
RateMultiplier: a.BillingRateMultiplier(),
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
@@ -146,34 +143,6 @@ func AccountFromServiceShallow(a *service.Account) *Account {
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
}
// 提取 5h 窗口费用控制和会话数量控制配置(仅 Anthropic OAuth/SetupToken 账号有效)
if a.IsAnthropicOAuthOrSetupToken() {
if limit := a.GetWindowCostLimit(); limit > 0 {
out.WindowCostLimit = &limit
}
if reserve := a.GetWindowCostStickyReserve(); reserve > 0 {
out.WindowCostStickyReserve = &reserve
}
if maxSessions := a.GetMaxSessions(); maxSessions > 0 {
out.MaxSessions = &maxSessions
}
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
out.SessionIdleTimeoutMin = &idleTimeout
}
// TLS指纹伪装开关
if a.IsTLSFingerprintEnabled() {
enabled := true
out.EnableTLSFingerprint = &enabled
}
// 会话ID伪装开关
if a.IsSessionIDMaskingEnabled() {
enabled := true
out.EnableSessionIDMasking = &enabled
}
}
return out
}
func AccountFromService(a *service.Account) *Account {
@@ -243,29 +212,8 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
return nil
}
return &ProxyWithAccountCount{
Proxy: *ProxyFromService(&p.Proxy),
AccountCount: p.AccountCount,
LatencyMs: p.LatencyMs,
LatencyStatus: p.LatencyStatus,
LatencyMessage: p.LatencyMessage,
IPAddress: p.IPAddress,
Country: p.Country,
CountryCode: p.CountryCode,
Region: p.Region,
City: p.City,
}
}
func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
if a == nil {
return nil
}
return &ProxyAccountSummary{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Notes: a.Notes,
Proxy: *ProxyFromService(&p.Proxy),
AccountCount: p.AccountCount,
}
}
@@ -331,7 +279,6 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
TotalCost: l.TotalCost,
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
AccountRateMultiplier: l.AccountRateMultiplier,
BillingType: l.BillingType,
Stream: l.Stream,
DurationMs: l.DurationMs,
@@ -368,36 +315,6 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
}
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
if task == nil {
return nil
}
return &UsageCleanupTask{
ID: task.ID,
Status: task.Status,
Filters: UsageCleanupFilters{
StartTime: task.Filters.StartTime,
EndTime: task.Filters.EndTime,
UserID: task.Filters.UserID,
APIKeyID: task.Filters.APIKeyID,
AccountID: task.Filters.AccountID,
GroupID: task.Filters.GroupID,
Model: task.Filters.Model,
Stream: task.Filters.Stream,
BillingType: task.Filters.BillingType,
},
CreatedBy: task.CreatedBy,
DeletedRows: task.DeletedRows,
ErrorMessage: task.ErrorMsg,
CanceledBy: task.CanceledBy,
CanceledAt: task.CanceledAt,
StartedAt: task.StartedAt,
FinishedAt: task.FinishedAt,
CreatedAt: task.CreatedAt,
UpdatedAt: task.UpdatedAt,
}
}
func SettingFromService(s *service.Setting) *Setting {
if s == nil {
return nil

View File

@@ -43,12 +43,6 @@ type SystemSettings struct {
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
// Ops monitoring (vNext)
OpsMonitoringEnabled bool `json:"ops_monitoring_enabled"`
OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault string `json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"`
}
type PublicSettings struct {
@@ -66,12 +60,3 @@ type PublicSettings struct {
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
}
// StreamTimeoutSettings 流超时处理配置 DTO
type StreamTimeoutSettings struct {
Enabled bool `json:"enabled"`
Action string `json:"action"`
TempUnschedMinutes int `json:"temp_unsched_minutes"`
ThresholdCount int `json:"threshold_count"`
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
}

View File

@@ -58,10 +58,6 @@ type Group struct {
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
@@ -80,7 +76,6 @@ type Account struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
RateMultiplier float64 `json:"rate_multiplier"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
@@ -102,25 +97,6 @@ type Account struct {
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `json:"session_window_status"`
// 5h窗口费用控制仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
WindowCostLimit *float64 `json:"window_cost_limit,omitempty"`
WindowCostStickyReserve *float64 `json:"window_cost_sticky_reserve,omitempty"`
// 会话数量控制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
MaxSessions *int `json:"max_sessions,omitempty"`
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
// TLS指纹伪装仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
// 会话ID伪装仅 Anthropic OAuth/SetupToken 账号有效)
// 启用后将在15分钟内固定 metadata.user_id 中的 session ID
// 从 extra 字段提取,方便前端显示和编辑
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -153,23 +129,7 @@ type Proxy struct {
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
LatencyMs *int64 `json:"latency_ms,omitempty"`
LatencyStatus string `json:"latency_status,omitempty"`
LatencyMessage string `json:"latency_message,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
Country string `json:"country,omitempty"`
CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
}
type ProxyAccountSummary struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Notes *string `json:"notes,omitempty"`
AccountCount int64 `json:"account_count"`
}
type RedeemCode struct {
@@ -209,14 +169,13 @@ type UsageLog struct {
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"`
@@ -242,33 +201,6 @@ type UsageLog struct {
Subscription *UserSubscription `json:"subscription,omitempty"`
}
type UsageCleanupFilters struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
UserID *int64 `json:"user_id,omitempty"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
AccountID *int64 `json:"account_id,omitempty"`
GroupID *int64 `json:"group_id,omitempty"`
Model *string `json:"model,omitempty"`
Stream *bool `json:"stream,omitempty"`
BillingType *int8 `json:"billing_type,omitempty"`
}
type UsageCleanupTask struct {
ID int64 `json:"id"`
Status string `json:"status"`
Filters UsageCleanupFilters `json:"filters"`
CreatedBy int64 `json:"created_by"`
DeletedRows int64 `json:"deleted_rows"`
ErrorMessage *string `json:"error_message,omitempty"`
CanceledBy *int64 `json:"canceled_by,omitempty"`
CanceledAt *time.Time `json:"canceled_at,omitempty"`
StartedAt *time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AccountSummary is a minimal account info for usage log display.
// It intentionally excludes sensitive fields like Credentials, Proxy, etc.
type AccountSummary struct {

View File

@@ -31,8 +31,6 @@ type GatewayHandler struct {
userService *service.UserService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
}
// NewGatewayHandler creates a new GatewayHandler
@@ -46,16 +44,8 @@ func NewGatewayHandler(
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 10
maxAccountSwitchesGemini := 3
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
if cfg.Gateway.MaxAccountSwitchesGemini > 0 {
maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
}
}
return &GatewayHandler{
gatewayService: gatewayService,
@@ -64,8 +54,6 @@ func NewGatewayHandler(
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
}
}
@@ -101,11 +89,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
@@ -114,7 +97,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
setOpsRequestContext(c, reqModel, reqStream, body)
// 设置 Claude Code 客户端标识到 context用于分组限制检查
SetClaudeCodeClientContext(c, body)
// 验证 model 必填
if reqModel == "" {
@@ -128,10 +112,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 获取订阅信息可能为nil- 提前获取用于后续检查
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 获取 User-Agent
userAgent := c.Request.UserAgent()
// 获取客户端 IP
clientIP := ip.GetClientIP(c)
// 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
@@ -139,15 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
// Ensure we decrement if we exit before acquiring the user slot.
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
// 确保在函数退出时减少wait计数
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
@@ -156,11 +138,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// User slot acquired: no longer waiting in the queue.
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
// 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
@@ -191,13 +168,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
if platform == service.PlatformGemini {
maxAccountSwitches := h.maxAccountSwitchesGemini
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -207,7 +184,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
@@ -224,12 +200,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 3. 获取账号并发槽位
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
@@ -237,16 +213,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
// Ensure the wait counter is decremented if we exit before acquiring the slot.
defer func() {
if accountWaitCounted {
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -257,21 +229,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// Slot acquired: no longer waiting in queue.
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
@@ -283,15 +254,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
@@ -301,12 +276,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -316,7 +287,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: clientIP,
IPAddress: cip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
@@ -325,14 +296,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
maxAccountSwitches := h.maxAccountSwitches
const maxAccountSwitches = 10
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -342,7 +313,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
@@ -359,12 +329,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 3. 获取账号并发槽位
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
@@ -372,15 +342,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -391,20 +358,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
@@ -416,15 +383,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
@@ -434,12 +405,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -449,7 +416,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: clientIP,
IPAddress: cip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
@@ -719,22 +686,21 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 设置 Claude Code 客户端标识到 context用于分组限制检查
SetClaudeCodeClientContext(c, body)
// 验证 model 必填
if parsedReq.Model == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
// 获取订阅信息可能为nil
subscription, _ := middleware2.GetSubscriptionFromContext(c)
@@ -755,7 +721,6 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
setOpsSelectedAccount(c, account.ID)
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {

View File

@@ -162,32 +162,28 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return
}
setOpsRequestContext(c, modelName, stream, body)
// Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)
// 获取 User-Agent
userAgent := c.Request.UserAgent()
// 获取客户端 IP
clientIP := ip.GetClientIP(c)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
// 0) wait queue check
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
} else if !canWait {
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
}
}()
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
// 1) user concurrency slot
streamStarted := false
@@ -196,10 +192,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if waitCounted {
geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
waitCounted = false
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
@@ -215,18 +207,22 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
// 设置 Claude Code 客户端标识到 context用于分组限制检查
SetClaudeCodeClientContext(c, body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
maxAccountSwitches := h.maxAccountSwitchesGemini
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
@@ -236,16 +232,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}
accountWaitCounted := false
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
@@ -253,15 +248,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log.Printf("Account wait queue full: account=%d", account.ID)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
}
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
c,
@@ -272,19 +264,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if accountWaitCounted {
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 5) forward (根据平台分流)
var result *service.ForwardResult
@@ -296,6 +288,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -315,12 +310,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 6) record usage async
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -330,7 +321,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
IPAddress: cip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}

View File

@@ -18,7 +18,6 @@ type AdminHandlers struct {
Redeem *admin.RedeemHandler
Promo *admin.PromoHandler
Setting *admin.SettingHandler
Ops *admin.OpsHandler
System *admin.SystemHandler
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler

View File

@@ -8,7 +8,6 @@ import (
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -25,7 +24,6 @@ type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -36,18 +34,13 @@ func NewOpenAIGatewayHandler(
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 3
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
}
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
}
}
@@ -83,8 +76,6 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
setOpsRequestContext(c, "", false, body)
// Parse request body to map for potential modification
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
@@ -102,41 +93,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
// For non-Codex CLI requests, set default instructions
userAgent := c.GetHeader("User-Agent")
// 获取客户端 IP
clientIP := ip.GetClientIP(c)
if !openai.IsCodexCLIRequest(userAgent) {
existingInstructions, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existingInstructions) == "" {
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
reqBody["instructions"] = instructions
// Re-serialize body
body, err = json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
}
}
setOpsRequestContext(c, reqModel, reqStream, body)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
// 要求 previous_response_id或 input 内存在带 call_id 的 tool_call/function_call
// 或带 id 且与 call_id 匹配的 item_reference。
if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
if service.HasFunctionCallOutputMissingCallID(reqBody) {
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
}
reqBody["instructions"] = openai.DefaultInstructions
// Re-serialize body
body, err = json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
@@ -149,7 +118,6 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
@@ -157,14 +125,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
// Ensure wait count is decremented when function exits
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
@@ -173,11 +135,6 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// User slot acquired: no longer waiting.
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
@@ -192,10 +149,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
// Generate session hash (from header for OpenAI)
sessionHash := h.gatewayService.GenerateSessionHash(c)
maxAccountSwitches := h.maxAccountSwitches
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
@@ -215,16 +172,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
setOpsSelectedAccount(c, account.ID)
// 3. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
@@ -232,15 +188,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -251,26 +204,29 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -290,12 +246,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) {
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
@@ -305,7 +257,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
IPAddress: cip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}

File diff suppressed because it is too large Load Diff

View File

@@ -21,7 +21,6 @@ func ProvideAdminHandlers(
redeemHandler *admin.RedeemHandler,
promoHandler *admin.PromoHandler,
settingHandler *admin.SettingHandler,
opsHandler *admin.OpsHandler,
systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler,
@@ -40,7 +39,6 @@ func ProvideAdminHandlers(
Redeem: redeemHandler,
Promo: promoHandler,
Setting: settingHandler,
Ops: opsHandler,
System: systemHandler,
Subscription: subscriptionHandler,
Usage: usageHandler,
@@ -111,7 +109,6 @@ var ProviderSet = wire.NewSet(
admin.NewRedeemHandler,
admin.NewPromoHandler,
admin.NewSettingHandler,
admin.NewOpsHandler,
ProvideSystemHandler,
admin.NewSubscriptionHandler,
admin.NewUsageHandler,

View File

@@ -1,63 +1,13 @@
package middleware
import (
"context"
"fmt"
"log"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// RateLimitFailureMode Redis 故障策略
type RateLimitFailureMode int
const (
RateLimitFailOpen RateLimitFailureMode = iota
RateLimitFailClose
)
// RateLimitOptions 限流可选配置
type RateLimitOptions struct {
FailureMode RateLimitFailureMode
}
var rateLimitScript = redis.NewScript(`
local current = redis.call('INCR', KEYS[1])
local ttl = redis.call('PTTL', KEYS[1])
local repaired = 0
if current == 1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
elseif ttl == -1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
repaired = 1
end
return {current, repaired}
`)
// rateLimitRun 允许测试覆写脚本执行逻辑
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice()
if err != nil {
return 0, false, err
}
if len(values) < 2 {
return 0, false, fmt.Errorf("rate limit script returned %d values", len(values))
}
count, err := parseInt64(values[0])
if err != nil {
return 0, false, err
}
repaired, err := parseInt64(values[1])
if err != nil {
return 0, false, err
}
return count, repaired == 1, nil
}
// RateLimiter Redis 速率限制器
type RateLimiter struct {
redis *redis.Client
@@ -77,85 +27,34 @@ func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
// limit: 时间窗口内最大请求数
// window: 时间窗口
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
return r.LimitWithOptions(key, limit, window, RateLimitOptions{})
}
// LimitWithOptions 返回速率限制中间件(带可选配置)
func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Duration, opts RateLimitOptions) gin.HandlerFunc {
failureMode := opts.FailureMode
if failureMode != RateLimitFailClose {
failureMode = RateLimitFailOpen
}
return func(c *gin.Context) {
ip := c.ClientIP()
redisKey := r.prefix + key + ":" + ip
ctx := c.Request.Context()
windowMillis := windowTTLMillis(window)
// 使用 Lua 脚本原子操作增加计数并设置过期
count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
// 使用 INCR 原子操作增加计数
count, err := r.redis.Incr(ctx, redisKey).Result()
if err != nil {
log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err)
if failureMode == RateLimitFailClose {
abortRateLimit(c)
return
}
// Redis 错误时放行,避免影响正常服务
c.Next()
return
}
if repaired {
log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis)
// 首次访问时设置过期时间
if count == 1 {
r.redis.Expire(ctx, redisKey, window)
}
// 超过限制
if count > int64(limit) {
abortRateLimit(c)
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded",
"message": "Too many requests, please try again later",
})
return
}
c.Next()
}
}
func windowTTLMillis(window time.Duration) int64 {
ttl := window.Milliseconds()
if ttl < 1 {
return 1
}
return ttl
}
func abortRateLimit(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded",
"message": "Too many requests, please try again later",
})
}
func failureModeLabel(mode RateLimitFailureMode) string {
if mode == RateLimitFailClose {
return "fail-close"
}
return "fail-open"
}
func parseInt64(value any) (int64, error) {
switch v := value.(type) {
case int64:
return v, nil
case int:
return int64(v), nil
case string:
parsed, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return 0, err
}
return parsed, nil
default:
return 0, fmt.Errorf("unexpected value type %T", value)
}
}

View File

@@ -1,158 +0,0 @@
//go:build integration
package middleware
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
)
const redisImageTag = "redis:8.4-alpine"
func TestRateLimiterSetsTTLAndDoesNotRefresh(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx := context.Background()
rdb := startRedis(t, ctx)
limiter := NewRateLimiter(rdb)
router := gin.New()
router.Use(limiter.Limit("ttl-test", 10, 2*time.Second))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
recorder := performRequest(router)
require.Equal(t, http.StatusOK, recorder.Code)
redisKey := limiter.prefix + "ttl-test:127.0.0.1"
ttlBefore, err := rdb.PTTL(ctx, redisKey).Result()
require.NoError(t, err)
require.Greater(t, ttlBefore, time.Duration(0))
require.LessOrEqual(t, ttlBefore, 2*time.Second)
time.Sleep(50 * time.Millisecond)
recorder = performRequest(router)
require.Equal(t, http.StatusOK, recorder.Code)
ttlAfter, err := rdb.PTTL(ctx, redisKey).Result()
require.NoError(t, err)
require.Less(t, ttlAfter, ttlBefore)
}
func TestRateLimiterFixesMissingTTL(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx := context.Background()
rdb := startRedis(t, ctx)
limiter := NewRateLimiter(rdb)
router := gin.New()
router.Use(limiter.Limit("ttl-missing", 10, 2*time.Second))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
redisKey := limiter.prefix + "ttl-missing:127.0.0.1"
require.NoError(t, rdb.Set(ctx, redisKey, 5, 0).Err())
ttlBefore, err := rdb.PTTL(ctx, redisKey).Result()
require.NoError(t, err)
require.Less(t, ttlBefore, time.Duration(0))
recorder := performRequest(router)
require.Equal(t, http.StatusOK, recorder.Code)
ttlAfter, err := rdb.PTTL(ctx, redisKey).Result()
require.NoError(t, err)
require.Greater(t, ttlAfter, time.Duration(0))
}
func performRequest(router *gin.Engine) *httptest.ResponseRecorder {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
return recorder
}
func startRedis(t *testing.T, ctx context.Context) *redis.Client {
t.Helper()
ensureDockerAvailable(t)
redisContainer, err := tcredis.Run(ctx, redisImageTag)
require.NoError(t, err)
t.Cleanup(func() {
_ = redisContainer.Terminate(ctx)
})
redisHost, err := redisContainer.Host(ctx)
require.NoError(t, err)
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
require.NoError(t, err)
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
DB: 0,
})
require.NoError(t, rdb.Ping(ctx).Err())
t.Cleanup(func() {
_ = rdb.Close()
})
return rdb
}
func ensureDockerAvailable(t *testing.T) {
t.Helper()
if dockerAvailable() {
return
}
t.Skip("Docker 未启用,跳过依赖 testcontainers 的集成测试")
}
func dockerAvailable() bool {
if os.Getenv("DOCKER_HOST") != "" {
return true
}
socketCandidates := []string{
"/var/run/docker.sock",
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
filepath.Join(userHomeDir(), ".docker", "run", "docker.sock"),
filepath.Join(userHomeDir(), ".docker", "desktop", "docker.sock"),
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
}
for _, socket := range socketCandidates {
if socket == "" {
continue
}
if _, err := os.Stat(socket); err == nil {
return true
}
}
return false
}
func userHomeDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return home
}

View File

@@ -1,100 +0,0 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func TestWindowTTLMillis(t *testing.T) {
require.Equal(t, int64(1), windowTTLMillis(500*time.Microsecond))
require.Equal(t, int64(1), windowTTLMillis(1500*time.Microsecond))
require.Equal(t, int64(2), windowTTLMillis(2500*time.Microsecond))
}
func TestRateLimiterFailureModes(t *testing.T) {
gin.SetMode(gin.TestMode)
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
limiter := NewRateLimiter(rdb)
failOpenRouter := gin.New()
failOpenRouter.Use(limiter.Limit("test", 1, time.Second))
failOpenRouter.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder := httptest.NewRecorder()
failOpenRouter.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
failCloseRouter := gin.New()
failCloseRouter.Use(limiter.LimitWithOptions("test", 1, time.Second, RateLimitOptions{
FailureMode: RateLimitFailClose,
}))
failCloseRouter.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req = httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder = httptest.NewRecorder()
failCloseRouter.ServeHTTP(recorder, req)
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
}
func TestRateLimiterSuccessAndLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
originalRun := rateLimitRun
counts := []int64{1, 2}
callIndex := 0
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
if callIndex >= len(counts) {
return counts[len(counts)-1], false, nil
}
value := counts[callIndex]
callIndex++
return value, false, nil
}
t.Cleanup(func() {
rateLimitRun = originalRun
})
limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}))
router := gin.New()
router.Use(limiter.Limit("test", 1, time.Second))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
req = httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
}

View File

@@ -16,6 +16,15 @@ import (
"time"
)
// resolveHost 从 URL 解析 host
func resolveHost(urlStr string) string {
parsed, err := url.Parse(urlStr)
if err != nil {
return ""
}
return parsed.Host
}
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求v1internal 端点)
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
// 构建 URL流式请求添加 ?alt=sse 参数
@@ -30,11 +39,23 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
return nil, err
}
// 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个)
// 基础 Headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", UserAgent)
// Accept Header 根据请求类型设置
if isStream {
req.Header.Set("Accept", "text/event-stream")
} else {
req.Header.Set("Accept", "application/json")
}
// 显式设置 Host Header
if host := resolveHost(apiURL); host != "" {
req.Host = host
}
return req, nil
}
@@ -174,15 +195,12 @@ func isConnectionError(err error) bool {
}
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
// 与 Antigravity-Manager 保持一致连接错误、429、408、404、5xx 触发 URL 降级
// 仅连接错误和 HTTP 429 触发 URL 降级
func shouldFallbackToNextURL(err error, statusCode int) bool {
if isConnectionError(err) {
return true
}
return statusCode == http.StatusTooManyRequests ||
statusCode == http.StatusRequestTimeout ||
statusCode == http.StatusNotFound ||
statusCode >= 500
return statusCode == http.StatusTooManyRequests
}
// ExchangeCode 用 authorization code 交换 token
@@ -303,8 +321,11 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
}
// 固定顺序prod -> daily
availableURLs := BaseURLs
// 获取可用的 URL 列表
availableURLs := DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
}
var lastErr error
for urlIdx, baseURL := range availableURLs {
@@ -322,6 +343,7 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
if err != nil {
lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
@@ -336,6 +358,7 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
// 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
@@ -353,8 +376,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp)
// 标记成功的 URL下次优先使用
DefaultURLAvailability.MarkSuccess(baseURL)
return &loadResp, rawResp, nil
}
@@ -391,8 +412,11 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
}
// 固定顺序prod -> daily
availableURLs := BaseURLs
// 获取可用的 URL 列表
availableURLs := DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
}
var lastErr error
for urlIdx, baseURL := range availableURLs {
@@ -410,6 +434,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
if err != nil {
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
@@ -424,6 +449,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
// 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
@@ -441,8 +467,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp)
// 标记成功的 URL下次优先使用
DefaultURLAvailability.MarkSuccess(baseURL)
return &modelsResp, rawResp, nil
}

View File

@@ -143,10 +143,9 @@ type GeminiResponse struct {
// GeminiCandidate Gemini 候选响应
type GeminiCandidate struct {
Content *GeminiContent `json:"content,omitempty"`
FinishReason string `json:"finishReason,omitempty"`
Index int `json:"index,omitempty"`
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
Content *GeminiContent `json:"content,omitempty"`
FinishReason string `json:"finishReason,omitempty"`
Index int `json:"index,omitempty"`
}
// GeminiUsageMetadata Gemini 用量元数据
@@ -157,23 +156,6 @@ type GeminiUsageMetadata struct {
TotalTokenCount int `json:"totalTokenCount,omitempty"`
}
// GeminiGroundingMetadata Gemini grounding 元数据Web Search
type GeminiGroundingMetadata struct {
WebSearchQueries []string `json:"webSearchQueries,omitempty"`
GroundingChunks []GeminiGroundingChunk `json:"groundingChunks,omitempty"`
}
// GeminiGroundingChunk Gemini grounding chunk
type GeminiGroundingChunk struct {
Web *GeminiGroundingWeb `json:"web,omitempty"`
}
// GeminiGroundingWeb Gemini grounding web 信息
type GeminiGroundingWeb struct {
Title string `json:"title,omitempty"`
URI string `json:"uri,omitempty"`
}
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
var DefaultSafetySettings = []GeminiSafetySetting{
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},

View File

@@ -32,8 +32,8 @@ const (
"https://www.googleapis.com/auth/cclog " +
"https://www.googleapis.com/auth/experimentsandconfigs"
// User-Agent与 Antigravity-Manager 保持一致
UserAgent = "antigravity/1.11.9 windows/amd64"
// User-Agent模拟官方客户端
UserAgent = "antigravity/1.104.0 darwin/arm64"
// Session 过期时间
SessionTTL = 30 * time.Minute
@@ -42,21 +42,22 @@ const (
URLAvailabilityTTL = 5 * time.Minute
)
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
// BaseURLs 定义 Antigravity API 端点,按优先级排序
// fallback 顺序: sandbox → daily → prod
var BaseURLs = []string{
"https://cloudcode-pa.googleapis.com", // prod (优先)
"https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用)
"https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox
"https://daily-cloudcode-pa.googleapis.com", // daily
"https://cloudcode-pa.googleapis.com", // prod
}
// BaseURL 默认 URL保持向后兼容
var BaseURL = BaseURLs[0]
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复)
type URLAvailability struct {
mu sync.RWMutex
unavailable map[string]time.Time // URL -> 恢复时间
ttl time.Duration
lastSuccess string // 最近成功请求的 URL优先使用
}
// DefaultURLAvailability 全局 URL 可用性管理器
@@ -77,15 +78,6 @@ func (u *URLAvailability) MarkUnavailable(url string) {
u.unavailable[url] = time.Now().Add(u.ttl)
}
// MarkSuccess 标记 URL 请求成功,将其设为优先使用
func (u *URLAvailability) MarkSuccess(url string) {
u.mu.Lock()
defer u.mu.Unlock()
u.lastSuccess = url
// 成功后清除该 URL 的不可用标记
delete(u.unavailable, url)
}
// IsAvailable 检查 URL 是否可用
func (u *URLAvailability) IsAvailable(url string) bool {
u.mu.RLock()
@@ -97,29 +89,14 @@ func (u *URLAvailability) IsAvailable(url string) bool {
return time.Now().After(expiry)
}
// GetAvailableURLs 返回可用的 URL 列表
// 最近成功的 URL 优先,其他按默认顺序
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序)
func (u *URLAvailability) GetAvailableURLs() []string {
u.mu.RLock()
defer u.mu.RUnlock()
now := time.Now()
result := make([]string, 0, len(BaseURLs))
// 如果有最近成功的 URL 且可用,放在最前面
if u.lastSuccess != "" {
expiry, exists := u.unavailable[u.lastSuccess]
if !exists || now.After(expiry) {
result = append(result, u.lastSuccess)
}
}
// 添加其他可用的 URL按默认顺序
for _, url := range BaseURLs {
// 跳过已添加的 lastSuccess
if url == u.lastSuccess {
continue
}
expiry, exists := u.unavailable[url]
if !exists || now.After(expiry) {
result = append(result, url)
@@ -263,3 +240,24 @@ func BuildAuthorizationURL(state, codeChallenge string) string {
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// GenerateMockProjectID 生成随机 project_id当 API 不返回时使用)
// 格式:{形容词}-{名词}-{5位随机字符}
func GenerateMockProjectID() string {
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
randBytes, _ := GenerateRandomBytes(7)
adj := adjectives[int(randBytes[0])%len(adjectives)]
noun := nouns[int(randBytes[1])%len(nouns)]
// 生成 5 位随机字符a-z0-9
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
suffix := make([]byte, 5)
for i := 0; i < 5; i++ {
suffix[i] = charset[int(randBytes[i+2])%len(charset)]
}
return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
}

View File

@@ -54,9 +54,6 @@ func DefaultTransformOptions() TransformOptions {
}
}
// webSearchFallbackModel web_search 请求使用的降级模型
const webSearchFallbackModel = "gemini-2.5-flash"
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
@@ -67,23 +64,12 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
// 检测是否有 web_search 工具
hasWebSearchTool := hasWebSearchTool(claudeReq.Tools)
requestType := "agent"
targetModel := mappedModel
if hasWebSearchTool {
requestType = "web_search"
if targetModel != webSearchFallbackModel {
targetModel = webSearchFallbackModel
}
}
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(targetModel, "gemini-")
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 1. 构建 contents
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
@@ -92,7 +78,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
}
// 2. 构建 systemInstruction
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts, claudeReq.Tools)
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts)
// 3. 构建 generationConfig
reqForConfig := claudeReq
@@ -103,11 +89,6 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
reqCopy.Thinking = nil
reqForConfig = &reqCopy
}
if targetModel != "" && targetModel != reqForConfig.Model {
reqCopy := *reqForConfig
reqCopy.Model = targetModel
reqForConfig = &reqCopy
}
generationConfig := buildGenerationConfig(reqForConfig)
// 4. 构建 tools
@@ -146,8 +127,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
Project: projectID,
RequestID: "agent-" + uuid.New().String(),
UserAgent: "antigravity", // 固定值,与官方客户端一致
RequestType: requestType,
Model: targetModel,
RequestType: "agent",
Model: mappedModel,
Request: innerRequest,
}
@@ -173,40 +154,8 @@ func GetDefaultIdentityPatch() string {
return antigravityIdentity
}
// mcpXMLProtocol MCP XML 工具调用协议(与 Antigravity-Manager 保持一致)
const mcpXMLProtocol = `
==== MCP XML 工具调用协议 (Workaround) ====
当你需要调用名称以 ` + "`mcp__`" + ` 开头的 MCP 工具时:
1) 优先尝试 XML 格式调用:输出 ` + "`<mcp__tool_name>{\"arg\":\"value\"}</mcp__tool_name>`" + `
2) 必须直接输出 XML 块,无需 markdown 包装,内容为 JSON 格式的入参。
3) 这种方式具有更高的连通性和容错性,适用于大型结果返回场景。
===========================================`
// hasMCPTools 检测是否有 mcp__ 前缀的工具
func hasMCPTools(tools []ClaudeTool) bool {
for _, tool := range tools {
if strings.HasPrefix(tool.Name, "mcp__") {
return true
}
}
return false
}
// filterOpenCodePrompt 过滤 OpenCode 默认提示词,只保留用户自定义指令
func filterOpenCodePrompt(text string) string {
if !strings.Contains(text, "You are an interactive CLI tool") {
return text
}
// 提取 "Instructions from:" 及之后的部分
if idx := strings.Index(text, "Instructions from:"); idx >= 0 {
return text[idx:]
}
// 如果没有自定义指令,返回空
return ""
}
// buildSystemInstruction 构建 systemInstruction与 Antigravity-Manager 保持一致)
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
// buildSystemInstruction 构建 systemInstruction
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
var parts []GeminiPart
// 先解析用户的 system prompt检测是否已包含 Antigravity identity
@@ -218,14 +167,10 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
var sysStr string
if err := json.Unmarshal(system, &sysStr); err == nil {
if strings.TrimSpace(sysStr) != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr})
if strings.Contains(sysStr, "You are Antigravity") {
userHasAntigravityIdentity = true
}
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(sysStr)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
}
} else {
// 尝试解析为数组
@@ -233,14 +178,10 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if err := json.Unmarshal(system, &sysBlocks); err == nil {
for _, block := range sysBlocks {
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text})
if strings.Contains(block.Text, "You are Antigravity") {
userHasAntigravityIdentity = true
}
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(block.Text)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
}
}
}
@@ -259,16 +200,6 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
// 添加用户的 system prompt
parts = append(parts, userSystemParts...)
// 检测是否有 MCP 工具,如有则注入 XML 调用协议
if hasMCPTools(tools) {
parts = append(parts, GeminiPart{Text: mcpXMLProtocol})
}
// 如果用户没有提供 Antigravity 身份,添加结束标记
if !userHasAntigravityIdentity {
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
}
if len(parts) == 0 {
return nil
}
@@ -498,11 +429,6 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
StopSequences: DefaultStopSequences,
}
// 如果请求中指定了 MaxTokens使用请求值
if req.MaxTokens > 0 {
config.MaxOutputTokens = req.MaxTokens
}
// Thinking 配置
if req.Thinking != nil && req.Thinking.Type == "enabled" {
config.ThinkingConfig = &GeminiThinkingConfig{
@@ -532,43 +458,37 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
return config
}
func hasWebSearchTool(tools []ClaudeTool) bool {
for _, tool := range tools {
if isWebSearchTool(tool) {
return true
}
}
return false
}
func isWebSearchTool(tool ClaudeTool) bool {
if strings.HasPrefix(tool.Type, "web_search") || tool.Type == "google_search" {
return true
}
name := strings.TrimSpace(tool.Name)
switch name {
case "web_search", "google_search", "web_search_20250305":
return true
default:
return false
}
}
// buildTools 构建 tools
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
if len(tools) == 0 {
return nil
}
hasWebSearch := hasWebSearchTool(tools)
// 检查是否有 web_search 工具
hasWebSearch := false
for _, tool := range tools {
if tool.Name == "web_search" {
hasWebSearch = true
break
}
}
if hasWebSearch {
// Web Search 工具映射
return []GeminiToolDeclaration{{
GoogleSearch: &GeminiGoogleSearch{
EnhancedContent: &GeminiEnhancedContent{
ImageSearch: &GeminiImageSearch{
MaxResultCount: 5,
},
},
},
}}
}
// 普通工具
var funcDecls []GeminiFunctionDecl
for _, tool := range tools {
if isWebSearchTool(tool) {
continue
}
// 跳过无效工具名称
if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name")
@@ -611,20 +531,7 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
if len(funcDecls) == 0 {
if !hasWebSearch {
return nil
}
// Web Search 工具映射
return []GeminiToolDeclaration{{
GoogleSearch: &GeminiGoogleSearch{
EnhancedContent: &GeminiEnhancedContent{
ImageSearch: &GeminiImageSearch{
MaxResultCount: 5,
},
},
},
}}
return nil
}
return []GeminiToolDeclaration{{

View File

@@ -3,7 +3,6 @@ package antigravity
import (
"encoding/json"
"fmt"
"strings"
)
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
@@ -64,12 +63,6 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID,
p.processPart(&part)
}
if len(geminiResp.Candidates) > 0 {
if grounding := geminiResp.Candidates[0].GroundingMetadata; grounding != nil {
p.processGrounding(grounding)
}
}
// 刷新剩余内容
p.flushThinking()
p.flushText()
@@ -197,18 +190,6 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
}
}
func (p *NonStreamingProcessor) processGrounding(grounding *GeminiGroundingMetadata) {
groundingText := buildGroundingText(grounding)
if groundingText == "" {
return
}
p.flushThinking()
p.flushText()
p.textBuilder += groundingText
p.flushText()
}
// flushText 刷新 text builder
func (p *NonStreamingProcessor) flushText() {
if p.textBuilder == "" {
@@ -281,44 +262,6 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
}
}
func buildGroundingText(grounding *GeminiGroundingMetadata) string {
if grounding == nil {
return ""
}
var builder strings.Builder
if len(grounding.WebSearchQueries) > 0 {
_, _ = builder.WriteString("\n\n---\nWeb search queries: ")
_, _ = builder.WriteString(strings.Join(grounding.WebSearchQueries, ", "))
}
if len(grounding.GroundingChunks) > 0 {
var links []string
for i, chunk := range grounding.GroundingChunks {
if chunk.Web == nil {
continue
}
title := strings.TrimSpace(chunk.Web.Title)
if title == "" {
title = "Source"
}
uri := strings.TrimSpace(chunk.Web.URI)
if uri == "" {
uri = "#"
}
links = append(links, fmt.Sprintf("[%d] [%s](%s)", i+1, title, uri))
}
if len(links) > 0 {
_, _ = builder.WriteString("\n\nSources:\n")
_, _ = builder.WriteString(strings.Join(links, "\n"))
}
}
return builder.String()
}
// generateRandomID 生成随机 ID
func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

View File

@@ -27,8 +27,6 @@ type StreamingProcessor struct {
pendingSignature string
trailingSignature string
originalModel string
webSearchQueries []string
groundingChunks []GeminiGroundingChunk
// 累计 usage
inputTokens int
@@ -95,10 +93,6 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
}
}
if len(geminiResp.Candidates) > 0 {
p.captureGrounding(geminiResp.Candidates[0].GroundingMetadata)
}
// 检查是否结束
if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason
@@ -206,20 +200,6 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
return result.Bytes()
}
func (p *StreamingProcessor) captureGrounding(grounding *GeminiGroundingMetadata) {
if grounding == nil {
return
}
if len(grounding.WebSearchQueries) > 0 && len(p.webSearchQueries) == 0 {
p.webSearchQueries = append([]string(nil), grounding.WebSearchQueries...)
}
if len(grounding.GroundingChunks) > 0 && len(p.groundingChunks) == 0 {
p.groundingChunks = append([]GeminiGroundingChunk(nil), grounding.GroundingChunks...)
}
}
// processThinking 处理 thinking
func (p *StreamingProcessor) processThinking(text, signature string) []byte {
var result bytes.Buffer
@@ -437,23 +417,6 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
p.trailingSignature = ""
}
if len(p.webSearchQueries) > 0 || len(p.groundingChunks) > 0 {
groundingText := buildGroundingText(&GeminiGroundingMetadata{
WebSearchQueries: p.webSearchQueries,
GroundingChunks: p.groundingChunks,
})
if groundingText != "" {
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
"type": "text",
"text": "",
}))
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
"text": groundingText,
}))
_, _ = result.Write(p.endBlock())
}
}
// 确定 stop_reason
stopReason := "end_turn"
if p.usedTool {

View File

@@ -7,14 +7,7 @@ type Key string
const (
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform Key = "ctx_force_platform"
// ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。
ClientRequestID Key = "ctx_client_request_id"
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
RetryCount Key = "ctx_retry_count"
// IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置
Group Key = "ctx_group"

View File

@@ -16,11 +16,14 @@ type ModelsListResponse struct {
func DefaultModels() []Model {
methods := []string{"generateContent", "streamGenerateContent"}
return []Model{
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
}
}

View File

@@ -12,10 +12,10 @@ type Model struct {
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var DefaultModels = []Model{
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.

View File

@@ -162,11 +162,11 @@ func ParsePagination(c *gin.Context) (page, pageSize int) {
// 支持 page_size 和 limit 两种参数名
if ps := c.Query("page_size"); ps != "" {
if val, err := parseInt(ps); err == nil && val > 0 && val <= 1000 {
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
pageSize = val
}
} else if l := c.Query("limit"); l != "" {
if val, err := parseInt(l); err == nil && val > 0 && val <= 1000 {
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
pageSize = val
}
}

View File

@@ -1,568 +0,0 @@
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
// It uses the utls library to create TLS connections that mimic Node.js/Claude Code clients.
package tlsfingerprint
import (
"bufio"
"context"
"encoding/base64"
"fmt"
"log/slog"
"net"
"net/http"
"net/url"
utls "github.com/refraction-networking/utls"
"golang.org/x/net/proxy"
)
// Profile contains TLS fingerprint configuration.
type Profile struct {
Name string // Profile name for identification
CipherSuites []uint16
Curves []uint16
PointFormats []uint8
EnableGREASE bool
}
// Dialer creates TLS connections with custom fingerprints.
type Dialer struct {
profile *Profile
baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)
}
// HTTPProxyDialer creates TLS connections through HTTP/HTTPS proxies with custom fingerprints.
// It handles the CONNECT tunnel establishment before performing TLS handshake.
type HTTPProxyDialer struct {
profile *Profile
proxyURL *url.URL
}
// SOCKS5ProxyDialer creates TLS connections through SOCKS5 proxies with custom fingerprints.
// It uses golang.org/x/net/proxy to establish the SOCKS5 tunnel.
type SOCKS5ProxyDialer struct {
profile *Profile
proxyURL *url.URL
}
// Default TLS fingerprint values captured from Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)
// Captured using: tshark -i lo -f "tcp port 8443" -Y "tls.handshake.type == 1" -V
// JA3 Hash: 1a28e69016765d92e3b381168d68922c
//
// Note: JA3/JA4 may have slight variations due to:
// - Session ticket presence/absence
// - Extension negotiation state
var (
// defaultCipherSuites contains all 59 cipher suites from Claude CLI
// Order is critical for JA3 fingerprint matching
defaultCipherSuites = []uint16{
// TLS 1.3 cipher suites (MUST be first)
0x1302, // TLS_AES_256_GCM_SHA384
0x1303, // TLS_CHACHA20_POLY1305_SHA256
0x1301, // TLS_AES_128_GCM_SHA256
// ECDHE + AES-GCM
0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
0xc02b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
0xc02c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
// DHE + AES-GCM
0x009e, // TLS_DHE_RSA_WITH_AES_128_GCM_SHA256
// ECDHE/DHE + AES-CBC-SHA256/384
0xc027, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256
0x0067, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA256
0xc028, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384
0x006b, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA256
// DHE-DSS/RSA + AES-GCM
0x00a3, // TLS_DHE_DSS_WITH_AES_256_GCM_SHA384
0x009f, // TLS_DHE_RSA_WITH_AES_256_GCM_SHA384
// ChaCha20-Poly1305
0xcca9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
0xcca8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
0xccaa, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256
// AES-CCM (256-bit)
0xc0af, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8
0xc0ad, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM
0xc0a3, // TLS_DHE_RSA_WITH_AES_256_CCM_8
0xc09f, // TLS_DHE_RSA_WITH_AES_256_CCM
// ARIA (256-bit)
0xc05d, // TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384
0xc061, // TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384
0xc057, // TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384
0xc053, // TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384
// DHE-DSS + AES-GCM (128-bit)
0x00a2, // TLS_DHE_DSS_WITH_AES_128_GCM_SHA256
// AES-CCM (128-bit)
0xc0ae, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8
0xc0ac, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM
0xc0a2, // TLS_DHE_RSA_WITH_AES_128_CCM_8
0xc09e, // TLS_DHE_RSA_WITH_AES_128_CCM
// ARIA (128-bit)
0xc05c, // TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256
0xc060, // TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256
0xc056, // TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256
0xc052, // TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256
// ECDHE/DHE + AES-CBC-SHA384/256 (more)
0xc024, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384
0x006a, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA256
0xc023, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256
0x0040, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA256
// ECDHE/DHE + AES-CBC-SHA (legacy)
0xc00a, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
0xc014, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA
0x0039, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA
0x0038, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA
0xc009, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA
0xc013, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA
0x0033, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA
0x0032, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA
// RSA + AES-GCM/CCM/ARIA (non-PFS, 256-bit)
0x009d, // TLS_RSA_WITH_AES_256_GCM_SHA384
0xc0a1, // TLS_RSA_WITH_AES_256_CCM_8
0xc09d, // TLS_RSA_WITH_AES_256_CCM
0xc051, // TLS_RSA_WITH_ARIA_256_GCM_SHA384
// RSA + AES-GCM/CCM/ARIA (non-PFS, 128-bit)
0x009c, // TLS_RSA_WITH_AES_128_GCM_SHA256
0xc0a0, // TLS_RSA_WITH_AES_128_CCM_8
0xc09c, // TLS_RSA_WITH_AES_128_CCM
0xc050, // TLS_RSA_WITH_ARIA_128_GCM_SHA256
// RSA + AES-CBC (non-PFS, legacy)
0x003d, // TLS_RSA_WITH_AES_256_CBC_SHA256
0x003c, // TLS_RSA_WITH_AES_128_CBC_SHA256
0x0035, // TLS_RSA_WITH_AES_256_CBC_SHA
0x002f, // TLS_RSA_WITH_AES_128_CBC_SHA
// Renegotiation indication
0x00ff, // TLS_EMPTY_RENEGOTIATION_INFO_SCSV
}
// defaultCurves contains the 10 supported groups from Claude CLI (including FFDHE)
defaultCurves = []utls.CurveID{
utls.X25519, // 0x001d
utls.CurveP256, // 0x0017 (secp256r1)
utls.CurveID(0x001e), // x448
utls.CurveP521, // 0x0019 (secp521r1)
utls.CurveP384, // 0x0018 (secp384r1)
utls.CurveID(0x0100), // ffdhe2048
utls.CurveID(0x0101), // ffdhe3072
utls.CurveID(0x0102), // ffdhe4096
utls.CurveID(0x0103), // ffdhe6144
utls.CurveID(0x0104), // ffdhe8192
}
// defaultPointFormats contains all 3 point formats from Claude CLI
defaultPointFormats = []uint8{
0, // uncompressed
1, // ansiX962_compressed_prime
2, // ansiX962_compressed_char2
}
// defaultSignatureAlgorithms contains the 20 signature algorithms from Claude CLI
defaultSignatureAlgorithms = []utls.SignatureScheme{
0x0403, // ecdsa_secp256r1_sha256
0x0503, // ecdsa_secp384r1_sha384
0x0603, // ecdsa_secp521r1_sha512
0x0807, // ed25519
0x0808, // ed448
0x0809, // rsa_pss_pss_sha256
0x080a, // rsa_pss_pss_sha384
0x080b, // rsa_pss_pss_sha512
0x0804, // rsa_pss_rsae_sha256
0x0805, // rsa_pss_rsae_sha384
0x0806, // rsa_pss_rsae_sha512
0x0401, // rsa_pkcs1_sha256
0x0501, // rsa_pkcs1_sha384
0x0601, // rsa_pkcs1_sha512
0x0303, // ecdsa_sha224
0x0301, // rsa_pkcs1_sha224
0x0302, // dsa_sha224
0x0402, // dsa_sha256
0x0502, // dsa_sha384
0x0602, // dsa_sha512
}
)
// NewDialer creates a new TLS fingerprint dialer.
// baseDialer is used for TCP connection establishment (supports proxy scenarios).
// If baseDialer is nil, direct TCP dial is used.
func NewDialer(profile *Profile, baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *Dialer {
if baseDialer == nil {
baseDialer = (&net.Dialer{}).DialContext
}
return &Dialer{profile: profile, baseDialer: baseDialer}
}
// NewHTTPProxyDialer creates a new TLS fingerprint dialer that works through HTTP/HTTPS proxies.
// It establishes a CONNECT tunnel before performing TLS handshake with custom fingerprint.
func NewHTTPProxyDialer(profile *Profile, proxyURL *url.URL) *HTTPProxyDialer {
return &HTTPProxyDialer{profile: profile, proxyURL: proxyURL}
}
// NewSOCKS5ProxyDialer creates a new TLS fingerprint dialer that works through SOCKS5 proxies.
// It establishes a SOCKS5 tunnel before performing TLS handshake with custom fingerprint.
func NewSOCKS5ProxyDialer(profile *Profile, proxyURL *url.URL) *SOCKS5ProxyDialer {
return &SOCKS5ProxyDialer{profile: profile, proxyURL: proxyURL}
}
// DialTLSContext establishes a TLS connection through SOCKS5 proxy with the configured fingerprint.
// Flow: SOCKS5 CONNECT to target -> TLS handshake with utls on the tunnel
func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
slog.Debug("tls_fingerprint_socks5_connecting", "proxy", d.proxyURL.Host, "target", addr)
// Step 1: Create SOCKS5 dialer
var auth *proxy.Auth
if d.proxyURL.User != nil {
username := d.proxyURL.User.Username()
password, _ := d.proxyURL.User.Password()
auth = &proxy.Auth{
User: username,
Password: password,
}
}
// Determine proxy address
proxyAddr := d.proxyURL.Host
if d.proxyURL.Port() == "" {
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "1080") // Default SOCKS5 port
}
socksDialer, err := proxy.SOCKS5("tcp", proxyAddr, auth, proxy.Direct)
if err != nil {
slog.Debug("tls_fingerprint_socks5_dialer_failed", "error", err)
return nil, fmt.Errorf("create SOCKS5 dialer: %w", err)
}
// Step 2: Establish SOCKS5 tunnel to target
slog.Debug("tls_fingerprint_socks5_establishing_tunnel", "target", addr)
conn, err := socksDialer.Dial("tcp", addr)
if err != nil {
slog.Debug("tls_fingerprint_socks5_connect_failed", "error", err)
return nil, fmt.Errorf("SOCKS5 connect: %w", err)
}
slog.Debug("tls_fingerprint_socks5_tunnel_established")
// Step 3: Perform TLS handshake on the tunnel with utls fingerprint
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
slog.Debug("tls_fingerprint_socks5_starting_handshake", "host", host)
// Build ClientHello specification from profile (Node.js/Claude CLI fingerprint)
spec := buildClientHelloSpecFromProfile(d.profile)
slog.Debug("tls_fingerprint_socks5_clienthello_spec",
"cipher_suites", len(spec.CipherSuites),
"extensions", len(spec.Extensions),
"compression_methods", spec.CompressionMethods,
"tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
"tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
if d.profile != nil {
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
}
// Create uTLS connection on the tunnel
tlsConn := utls.UClient(conn, &utls.Config{
ServerName: host,
}, utls.HelloCustom)
if err := tlsConn.ApplyPreset(spec); err != nil {
slog.Debug("tls_fingerprint_socks5_apply_preset_failed", "error", err)
_ = conn.Close()
return nil, fmt.Errorf("apply TLS preset: %w", err)
}
if err := tlsConn.Handshake(); err != nil {
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
_ = conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
}
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_socks5_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
}
// DialTLSContext establishes a TLS connection through HTTP proxy with the configured fingerprint.
// Flow: TCP connect to proxy -> CONNECT tunnel -> TLS handshake with utls
func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
slog.Debug("tls_fingerprint_http_proxy_connecting", "proxy", d.proxyURL.Host, "target", addr)
// Step 1: TCP connect to proxy server
var proxyAddr string
if d.proxyURL.Port() != "" {
proxyAddr = d.proxyURL.Host
} else {
// Default ports
if d.proxyURL.Scheme == "https" {
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "443")
} else {
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "80")
}
}
dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, "tcp", proxyAddr)
if err != nil {
slog.Debug("tls_fingerprint_http_proxy_connect_failed", "error", err)
return nil, fmt.Errorf("connect to proxy: %w", err)
}
slog.Debug("tls_fingerprint_http_proxy_connected", "proxy_addr", proxyAddr)
// Step 2: Send CONNECT request to establish tunnel
req := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: addr},
Host: addr,
Header: make(http.Header),
}
// Add proxy authentication if present
if d.proxyURL.User != nil {
username := d.proxyURL.User.Username()
password, _ := d.proxyURL.User.Password()
auth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
req.Header.Set("Proxy-Authorization", "Basic "+auth)
}
slog.Debug("tls_fingerprint_http_proxy_sending_connect", "target", addr)
if err := req.Write(conn); err != nil {
_ = conn.Close()
slog.Debug("tls_fingerprint_http_proxy_write_failed", "error", err)
return nil, fmt.Errorf("write CONNECT request: %w", err)
}
// Step 3: Read CONNECT response
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, req)
if err != nil {
_ = conn.Close()
slog.Debug("tls_fingerprint_http_proxy_read_response_failed", "error", err)
return nil, fmt.Errorf("read CONNECT response: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
_ = conn.Close()
slog.Debug("tls_fingerprint_http_proxy_connect_failed_status", "status_code", resp.StatusCode, "status", resp.Status)
return nil, fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
}
slog.Debug("tls_fingerprint_http_proxy_tunnel_established")
// Step 4: Perform TLS handshake on the tunnel with utls fingerprint
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
slog.Debug("tls_fingerprint_http_proxy_starting_handshake", "host", host)
// Build ClientHello specification (reuse the shared method)
spec := buildClientHelloSpecFromProfile(d.profile)
slog.Debug("tls_fingerprint_http_proxy_clienthello_spec",
"cipher_suites", len(spec.CipherSuites),
"extensions", len(spec.Extensions))
if d.profile != nil {
slog.Debug("tls_fingerprint_http_proxy_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
}
// Create uTLS connection on the tunnel
// Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
tlsConn := utls.UClient(conn, &utls.Config{
ServerName: host,
}, utls.HelloCustom)
if err := tlsConn.ApplyPreset(spec); err != nil {
slog.Debug("tls_fingerprint_http_proxy_apply_preset_failed", "error", err)
_ = conn.Close()
return nil, fmt.Errorf("apply TLS preset: %w", err)
}
if err := tlsConn.HandshakeContext(ctx); err != nil {
slog.Debug("tls_fingerprint_http_proxy_handshake_failed", "error", err)
_ = conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
}
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_http_proxy_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
}
// DialTLSContext establishes a TLS connection with the configured fingerprint.
// This method is designed to be used as http.Transport.DialTLSContext.
func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
// Establish TCP connection using base dialer (supports proxy)
slog.Debug("tls_fingerprint_dialing_tcp", "addr", addr)
conn, err := d.baseDialer(ctx, network, addr)
if err != nil {
slog.Debug("tls_fingerprint_tcp_dial_failed", "error", err)
return nil, err
}
slog.Debug("tls_fingerprint_tcp_connected", "addr", addr)
// Extract hostname for SNI
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
slog.Debug("tls_fingerprint_sni_hostname", "host", host)
// Build ClientHello specification
spec := d.buildClientHelloSpec()
slog.Debug("tls_fingerprint_clienthello_spec",
"cipher_suites", len(spec.CipherSuites),
"extensions", len(spec.Extensions))
// Log profile info
if d.profile != nil {
slog.Debug("tls_fingerprint_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
} else {
slog.Debug("tls_fingerprint_using_default_profile")
}
// Create uTLS connection
// Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
tlsConn := utls.UClient(conn, &utls.Config{
ServerName: host,
}, utls.HelloCustom)
// Apply fingerprint
if err := tlsConn.ApplyPreset(spec); err != nil {
slog.Debug("tls_fingerprint_apply_preset_failed", "error", err)
_ = conn.Close()
return nil, err
}
slog.Debug("tls_fingerprint_preset_applied")
// Perform TLS handshake
if err := tlsConn.HandshakeContext(ctx); err != nil {
slog.Debug("tls_fingerprint_handshake_failed",
"error", err,
"local_addr", conn.LocalAddr(),
"remote_addr", conn.RemoteAddr())
_ = conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
}
// Log successful handshake details
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
}
// buildClientHelloSpec constructs the ClientHello specification based on the profile.
func (d *Dialer) buildClientHelloSpec() *utls.ClientHelloSpec {
return buildClientHelloSpecFromProfile(d.profile)
}
// toUTLSCurves converts uint16 slice to utls.CurveID slice.
func toUTLSCurves(curves []uint16) []utls.CurveID {
result := make([]utls.CurveID, len(curves))
for i, c := range curves {
result[i] = utls.CurveID(c)
}
return result
}
// buildClientHelloSpecFromProfile constructs ClientHelloSpec from a Profile.
// This is a standalone function that can be used by both Dialer and HTTPProxyDialer.
func buildClientHelloSpecFromProfile(profile *Profile) *utls.ClientHelloSpec {
// Get cipher suites
var cipherSuites []uint16
if profile != nil && len(profile.CipherSuites) > 0 {
cipherSuites = profile.CipherSuites
} else {
cipherSuites = defaultCipherSuites
}
// Get curves
var curves []utls.CurveID
if profile != nil && len(profile.Curves) > 0 {
curves = toUTLSCurves(profile.Curves)
} else {
curves = defaultCurves
}
// Get point formats
var pointFormats []uint8
if profile != nil && len(profile.PointFormats) > 0 {
pointFormats = profile.PointFormats
} else {
pointFormats = defaultPointFormats
}
// Check if GREASE is enabled
enableGREASE := profile != nil && profile.EnableGREASE
extensions := make([]utls.TLSExtension, 0, 16)
if enableGREASE {
extensions = append(extensions, &utls.UtlsGREASEExtension{})
}
// SNI extension - MUST be explicitly added for HelloCustom mode
// utls will populate the server name from Config.ServerName
extensions = append(extensions, &utls.SNIExtension{})
// Claude CLI extension order (captured from tshark):
// server_name(0), ec_point_formats(11), supported_groups(10), session_ticket(35),
// alpn(16), encrypt_then_mac(22), extended_master_secret(23),
// signature_algorithms(13), supported_versions(43),
// psk_key_exchange_modes(45), key_share(51)
extensions = append(extensions,
&utls.SupportedPointsExtension{SupportedPoints: pointFormats},
&utls.SupportedCurvesExtension{Curves: curves},
&utls.SessionTicketExtension{},
&utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}},
&utls.GenericExtension{Id: 22},
&utls.ExtendedMasterSecretExtension{},
&utls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: defaultSignatureAlgorithms},
&utls.SupportedVersionsExtension{Versions: []uint16{
utls.VersionTLS13,
utls.VersionTLS12,
}},
&utls.PSKKeyExchangeModesExtension{Modes: []uint8{utls.PskModeDHE}},
&utls.KeyShareExtension{KeyShares: []utls.KeyShare{
{Group: utls.X25519},
}},
)
if enableGREASE {
extensions = append(extensions, &utls.UtlsGREASEExtension{})
}
return &utls.ClientHelloSpec{
CipherSuites: cipherSuites,
CompressionMethods: []uint8{0}, // null compression only (standard)
Extensions: extensions,
TLSVersMax: utls.VersionTLS13,
TLSVersMin: utls.VersionTLS10,
}
}

View File

@@ -1,307 +0,0 @@
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests and should be run manually.
//
// Run with: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
)
// FingerprintResponse represents the response from tls.peet.ws/api/all.
type FingerprintResponse struct {
IP string `json:"ip"`
TLS TLSInfo `json:"tls"`
HTTP2 any `json:"http2"`
}
// TLSInfo contains TLS fingerprint details.
type TLSInfo struct {
JA3 string `json:"ja3"`
JA3Hash string `json:"ja3_hash"`
JA4 string `json:"ja4"`
PeetPrint string `json:"peetprint"`
PeetPrintHash string `json:"peetprint_hash"`
ClientRandom string `json:"client_random"`
SessionID string `json:"session_id"`
}
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
func TestDialerBasicConnection(t *testing.T) {
if testing.Short() {
t.Skip("skipping network test in short mode")
}
// Create a dialer with default profile
profile := &Profile{
Name: "Test Profile",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
// Create HTTP client with custom TLS dialer
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Make a request to a known HTTPS endpoint
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints.
func TestDialerWithProfile(t *testing.T) {
// Create two dialers with different profiles
profile1 := &Profile{
Name: "Profile 1 - No GREASE",
EnableGREASE: false,
}
profile2 := &Profile{
Name: "Profile 2 - With GREASE",
EnableGREASE: true,
}
dialer1 := NewDialer(profile1, nil)
dialer2 := NewDialer(profile2, nil)
// Build specs and compare
// Note: We can't directly compare JA3 without making network requests
// but we can verify the specs are different
spec1 := dialer1.buildClientHelloSpec()
spec2 := dialer2.buildClientHelloSpec()
// Profile with GREASE should have more extensions
if len(spec2.Extensions) <= len(spec1.Extensions) {
t.Error("expected GREASE profile to have more extensions")
}
}
// TestHTTPProxyDialerBasic tests HTTP proxy dialer creation.
// Note: This is a unit test - actual proxy testing requires a proxy server.
func TestHTTPProxyDialerBasic(t *testing.T) {
profile := &Profile{
Name: "Test Profile",
EnableGREASE: false,
}
// Test that dialer is created without panic
proxyURL := mustParseURL("http://proxy.example.com:8080")
dialer := NewHTTPProxyDialer(profile, proxyURL)
if dialer == nil {
t.Fatal("expected dialer to be created")
}
if dialer.profile != profile {
t.Error("expected profile to be set")
}
if dialer.proxyURL != proxyURL {
t.Error("expected proxyURL to be set")
}
}
// TestSOCKS5ProxyDialerBasic tests SOCKS5 proxy dialer creation.
// Note: This is a unit test - actual proxy testing requires a proxy server.
func TestSOCKS5ProxyDialerBasic(t *testing.T) {
profile := &Profile{
Name: "Test Profile",
EnableGREASE: false,
}
// Test that dialer is created without panic
proxyURL := mustParseURL("socks5://proxy.example.com:1080")
dialer := NewSOCKS5ProxyDialer(profile, proxyURL)
if dialer == nil {
t.Fatal("expected dialer to be created")
}
if dialer.profile != profile {
t.Error("expected profile to be set")
}
if dialer.proxyURL != proxyURL {
t.Error("expected proxyURL to be set")
}
}
// TestBuildClientHelloSpec tests ClientHello spec construction.
func TestBuildClientHelloSpec(t *testing.T) {
// Test with nil profile (should use defaults)
spec := buildClientHelloSpecFromProfile(nil)
if len(spec.CipherSuites) == 0 {
t.Error("expected cipher suites to be set")
}
if len(spec.Extensions) == 0 {
t.Error("expected extensions to be set")
}
// Verify default cipher suites are used
if len(spec.CipherSuites) != len(defaultCipherSuites) {
t.Errorf("expected %d cipher suites, got %d", len(defaultCipherSuites), len(spec.CipherSuites))
}
// Test with custom profile
customProfile := &Profile{
Name: "Custom",
EnableGREASE: false,
CipherSuites: []uint16{0x1301, 0x1302},
}
spec = buildClientHelloSpecFromProfile(customProfile)
if len(spec.CipherSuites) != 2 {
t.Errorf("expected 2 cipher suites, got %d", len(spec.CipherSuites))
}
}
// TestToUTLSCurves tests curve ID conversion.
func TestToUTLSCurves(t *testing.T) {
input := []uint16{0x001d, 0x0017, 0x0018}
result := toUTLSCurves(input)
if len(result) != len(input) {
t.Errorf("expected %d curves, got %d", len(input), len(result))
}
for i, curve := range result {
if uint16(curve) != input[i] {
t.Errorf("curve %d: expected 0x%04x, got 0x%04x", i, input[i], uint16(curve))
}
}
}
// Helper function to parse URL without error handling.
func mustParseURL(rawURL string) *url.URL {
u, err := url.Parse(rawURL)
if err != nil {
panic(err)
}
return u
}

View File

@@ -1,171 +0,0 @@
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
package tlsfingerprint
import (
"log/slog"
"sort"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// DefaultProfileName is the name of the built-in Claude CLI profile.
const DefaultProfileName = "claude_cli_v2"
// Registry manages TLS fingerprint profiles.
// It holds a collection of profiles that can be used for TLS fingerprint simulation.
// Profiles are selected based on account ID using modulo operation.
type Registry struct {
mu sync.RWMutex
profiles map[string]*Profile
profileNames []string // Sorted list of profile names for deterministic selection
}
// NewRegistry creates a new TLS fingerprint profile registry.
// It initializes with the built-in default profile.
func NewRegistry() *Registry {
r := &Registry{
profiles: make(map[string]*Profile),
profileNames: make([]string, 0),
}
// Register the built-in default profile
r.registerBuiltinProfile()
return r
}
// NewRegistryFromConfig creates a new registry and loads profiles from config.
// If the config has custom profiles defined, they will be merged with the built-in default.
func NewRegistryFromConfig(cfg *config.TLSFingerprintConfig) *Registry {
r := NewRegistry()
if cfg == nil || !cfg.Enabled {
slog.Debug("tls_registry_disabled", "reason", "disabled or no config")
return r
}
// Load custom profiles from config
for name, profileCfg := range cfg.Profiles {
profile := &Profile{
Name: profileCfg.Name,
EnableGREASE: profileCfg.EnableGREASE,
CipherSuites: profileCfg.CipherSuites,
Curves: profileCfg.Curves,
PointFormats: profileCfg.PointFormats,
}
// If the profile has empty values, they will use defaults in dialer
r.RegisterProfile(name, profile)
slog.Debug("tls_registry_loaded_profile", "key", name, "name", profileCfg.Name)
}
slog.Debug("tls_registry_initialized", "profile_count", len(r.profileNames), "profiles", r.profileNames)
return r
}
// registerBuiltinProfile adds the default Claude CLI profile to the registry.
func (r *Registry) registerBuiltinProfile() {
defaultProfile := &Profile{
Name: "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)",
EnableGREASE: false, // Node.js does not use GREASE
// Empty slices will cause dialer to use built-in defaults
CipherSuites: nil,
Curves: nil,
PointFormats: nil,
}
r.RegisterProfile(DefaultProfileName, defaultProfile)
}
// RegisterProfile adds or updates a profile in the registry.
func (r *Registry) RegisterProfile(name string, profile *Profile) {
r.mu.Lock()
defer r.mu.Unlock()
// Check if this is a new profile
_, exists := r.profiles[name]
r.profiles[name] = profile
if !exists {
r.profileNames = append(r.profileNames, name)
// Keep names sorted for deterministic selection
sort.Strings(r.profileNames)
}
}
// GetProfile returns a profile by name.
// Returns nil if the profile does not exist.
func (r *Registry) GetProfile(name string) *Profile {
r.mu.RLock()
defer r.mu.RUnlock()
return r.profiles[name]
}
// GetDefaultProfile returns the built-in default profile.
func (r *Registry) GetDefaultProfile() *Profile {
return r.GetProfile(DefaultProfileName)
}
// GetProfileByAccountID returns a profile for the given account ID.
// The profile is selected using: profileNames[accountID % len(profiles)]
// This ensures deterministic profile assignment for each account.
func (r *Registry) GetProfileByAccountID(accountID int64) *Profile {
r.mu.RLock()
defer r.mu.RUnlock()
if len(r.profileNames) == 0 {
return nil
}
// Use modulo to select profile index
// Use absolute value to handle negative IDs (though unlikely)
idx := accountID
if idx < 0 {
idx = -idx
}
selectedIndex := int(idx % int64(len(r.profileNames)))
selectedName := r.profileNames[selectedIndex]
return r.profiles[selectedName]
}
// ProfileCount returns the number of registered profiles.
func (r *Registry) ProfileCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.profiles)
}
// ProfileNames returns a sorted list of all registered profile names.
func (r *Registry) ProfileNames() []string {
r.mu.RLock()
defer r.mu.RUnlock()
// Return a copy to prevent modification
names := make([]string, len(r.profileNames))
copy(names, r.profileNames)
return names
}
// Global registry instance for convenience
var globalRegistry *Registry
var globalRegistryOnce sync.Once
// GlobalRegistry returns the global TLS fingerprint registry.
// The registry is lazily initialized with the default profile.
func GlobalRegistry() *Registry {
globalRegistryOnce.Do(func() {
globalRegistry = NewRegistry()
})
return globalRegistry
}
// InitGlobalRegistry initializes the global registry with configuration.
// This should be called during application startup.
// It is safe to call multiple times; subsequent calls will update the registry.
func InitGlobalRegistry(cfg *config.TLSFingerprintConfig) *Registry {
globalRegistryOnce.Do(func() {
globalRegistry = NewRegistryFromConfig(cfg)
})
return globalRegistry
}

View File

@@ -1,243 +0,0 @@
package tlsfingerprint
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
func TestNewRegistry(t *testing.T) {
r := NewRegistry()
// Should have exactly one profile (the default)
if r.ProfileCount() != 1 {
t.Errorf("expected 1 profile, got %d", r.ProfileCount())
}
// Should have the default profile
profile := r.GetDefaultProfile()
if profile == nil {
t.Error("expected default profile to exist")
}
// Default profile name should be in the list
names := r.ProfileNames()
if len(names) != 1 || names[0] != DefaultProfileName {
t.Errorf("expected profile names to be [%s], got %v", DefaultProfileName, names)
}
}
func TestRegisterProfile(t *testing.T) {
r := NewRegistry()
// Register a new profile
customProfile := &Profile{
Name: "Custom Profile",
EnableGREASE: true,
}
r.RegisterProfile("custom", customProfile)
// Should now have 2 profiles
if r.ProfileCount() != 2 {
t.Errorf("expected 2 profiles, got %d", r.ProfileCount())
}
// Should be able to retrieve the custom profile
retrieved := r.GetProfile("custom")
if retrieved == nil {
t.Fatal("expected custom profile to exist")
}
if retrieved.Name != "Custom Profile" {
t.Errorf("expected profile name 'Custom Profile', got '%s'", retrieved.Name)
}
if !retrieved.EnableGREASE {
t.Error("expected EnableGREASE to be true")
}
}
func TestGetProfile(t *testing.T) {
r := NewRegistry()
// Get existing profile
profile := r.GetProfile(DefaultProfileName)
if profile == nil {
t.Error("expected default profile to exist")
}
// Get non-existing profile
nonExistent := r.GetProfile("nonexistent")
if nonExistent != nil {
t.Error("expected nil for non-existent profile")
}
}
func TestGetProfileByAccountID(t *testing.T) {
r := NewRegistry()
// With only default profile, all account IDs should return the same profile
for i := int64(0); i < 10; i++ {
profile := r.GetProfileByAccountID(i)
if profile == nil {
t.Errorf("expected profile for account %d, got nil", i)
}
}
// Add more profiles
r.RegisterProfile("profile_a", &Profile{Name: "Profile A"})
r.RegisterProfile("profile_b", &Profile{Name: "Profile B"})
// Now we have 3 profiles: claude_cli_v2, profile_a, profile_b
// Names are sorted, so order is: claude_cli_v2, profile_a, profile_b
expectedOrder := []string{DefaultProfileName, "profile_a", "profile_b"}
names := r.ProfileNames()
for i, name := range expectedOrder {
if names[i] != name {
t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i])
}
}
// Test modulo selection
// Account ID 0 % 3 = 0 -> claude_cli_v2
// Account ID 1 % 3 = 1 -> profile_a
// Account ID 2 % 3 = 2 -> profile_b
// Account ID 3 % 3 = 0 -> claude_cli_v2
testCases := []struct {
accountID int64
expectedName string
}{
{0, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"},
{1, "Profile A"},
{2, "Profile B"},
{3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"},
{4, "Profile A"},
{5, "Profile B"},
{100, "Profile A"}, // 100 % 3 = 1
{-1, "Profile A"}, // |-1| % 3 = 1
{-3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, // |-3| % 3 = 0
}
for _, tc := range testCases {
profile := r.GetProfileByAccountID(tc.accountID)
if profile == nil {
t.Errorf("expected profile for account %d, got nil", tc.accountID)
continue
}
if profile.Name != tc.expectedName {
t.Errorf("account %d: expected profile name '%s', got '%s'", tc.accountID, tc.expectedName, profile.Name)
}
}
}
func TestNewRegistryFromConfig(t *testing.T) {
// Test with nil config
r := NewRegistryFromConfig(nil)
if r.ProfileCount() != 1 {
t.Errorf("expected 1 profile with nil config, got %d", r.ProfileCount())
}
// Test with disabled config
disabledCfg := &config.TLSFingerprintConfig{
Enabled: false,
}
r = NewRegistryFromConfig(disabledCfg)
if r.ProfileCount() != 1 {
t.Errorf("expected 1 profile with disabled config, got %d", r.ProfileCount())
}
// Test with enabled config and custom profiles
enabledCfg := &config.TLSFingerprintConfig{
Enabled: true,
Profiles: map[string]config.TLSProfileConfig{
"custom1": {
Name: "Custom Profile 1",
EnableGREASE: true,
},
"custom2": {
Name: "Custom Profile 2",
EnableGREASE: false,
},
},
}
r = NewRegistryFromConfig(enabledCfg)
// Should have 3 profiles: default + 2 custom
if r.ProfileCount() != 3 {
t.Errorf("expected 3 profiles, got %d", r.ProfileCount())
}
// Check custom profiles exist
custom1 := r.GetProfile("custom1")
if custom1 == nil || custom1.Name != "Custom Profile 1" {
t.Error("expected custom1 profile to exist with correct name")
}
custom2 := r.GetProfile("custom2")
if custom2 == nil || custom2.Name != "Custom Profile 2" {
t.Error("expected custom2 profile to exist with correct name")
}
}
func TestProfileNames(t *testing.T) {
r := NewRegistry()
// Add profiles in non-alphabetical order
r.RegisterProfile("zebra", &Profile{Name: "Zebra"})
r.RegisterProfile("alpha", &Profile{Name: "Alpha"})
r.RegisterProfile("beta", &Profile{Name: "Beta"})
names := r.ProfileNames()
// Should be sorted alphabetically
expected := []string{"alpha", "beta", DefaultProfileName, "zebra"}
if len(names) != len(expected) {
t.Errorf("expected %d names, got %d", len(expected), len(names))
}
for i, name := range expected {
if names[i] != name {
t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i])
}
}
// Test that returned slice is a copy (modifying it shouldn't affect registry)
names[0] = "modified"
originalNames := r.ProfileNames()
if originalNames[0] == "modified" {
t.Error("modifying returned slice should not affect registry")
}
}
func TestConcurrentAccess(t *testing.T) {
r := NewRegistry()
// Run concurrent reads and writes
done := make(chan bool)
// Writers
for i := 0; i < 10; i++ {
go func(id int) {
for j := 0; j < 100; j++ {
r.RegisterProfile("concurrent"+string(rune('0'+id)), &Profile{Name: "Concurrent"})
}
done <- true
}(i)
}
// Readers
for i := 0; i < 10; i++ {
go func(id int) {
for j := 0; j < 100; j++ {
_ = r.ProfileCount()
_ = r.ProfileNames()
_ = r.GetProfileByAccountID(int64(id * j))
_ = r.GetProfile(DefaultProfileName)
}
done <- true
}(i)
}
// Wait for all goroutines
for i := 0; i < 20; i++ {
<-done
}
// Test should pass without data races (run with -race flag)
}

View File

@@ -1,14 +1,8 @@
package usagestats
// AccountStats 账号使用统计
//
// cost: 账号口径费用(使用 total_cost * account_rate_multiplier
// standard_cost: 标准费用(使用 total_cost不含倍率
// user_cost: 用户/API Key 口径费用(使用 actual_cost受分组倍率影响
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
StandardCost float64 `json:"standard_cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}

View File

@@ -9,12 +9,6 @@ type DashboardStats struct {
TotalUsers int64 `json:"total_users"`
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// 小时活跃用户数UTC 当前小时)
HourlyActiveUsers int64 `json:"hourly_active_users"`
// 预聚合新鲜度
StatsUpdatedAt string `json:"stats_updated_at"`
StatsStale bool `json:"stats_stale"`
// API Key 统计
TotalAPIKeys int64 `json:"total_api_keys"`
@@ -147,15 +141,14 @@ type UsageLogFilters struct {
// UsageStats represents usage statistics
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
AverageDurationMs float64 `json:"average_duration_ms"`
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// BatchUserUsageStats represents usage stats for a single user
@@ -178,29 +171,25 @@ type AccountUsageHistory struct {
Label string `json:"label"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费total_cost
ActualCost float64 `json:"actual_cost"` // 账号口径费用total_cost * account_rate_multiplier
UserCost float64 `json:"user_cost"` // 用户口径费用actual_cost受分组倍率影响
Cost float64 `json:"cost"`
ActualCost float64 `json:"actual_cost"`
}
// AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary struct {
Days int `json:"days"`
ActualDaysUsed int `json:"actual_days_used"`
TotalCost float64 `json:"total_cost"` // 账号口径费用
TotalUserCost float64 `json:"total_user_cost"` // 用户口径费用
TotalCost float64 `json:"total_cost"`
TotalStandardCost float64 `json:"total_standard_cost"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
AvgDailyCost float64 `json:"avg_daily_cost"` // 账号口径日均
AvgDailyUserCost float64 `json:"avg_daily_user_cost"`
AvgDailyCost float64 `json:"avg_daily_cost"`
AvgDailyRequests float64 `json:"avg_daily_requests"`
AvgDailyTokens float64 `json:"avg_daily_tokens"`
AvgDurationMs float64 `json:"avg_duration_ms"`
Today *struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
} `json:"today"`
@@ -208,7 +197,6 @@ type AccountUsageSummary struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
} `json:"highest_cost_day"`
HighestRequestDay *struct {
@@ -216,7 +204,6 @@ type AccountUsageSummary struct {
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
} `json:"highest_request_day"`
}

View File

@@ -15,7 +15,6 @@ import (
"database/sql"
"encoding/json"
"errors"
"log"
"strconv"
"time"
@@ -80,10 +79,6 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
}
@@ -120,9 +115,6 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
account.ID = created.ID
account.CreatedAt = created.CreatedAt
account.UpdatedAt = created.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
}
return nil
}
@@ -295,10 +287,6 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
} else {
@@ -353,17 +341,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
account.UpdatedAt = updated.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
return nil
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
groupIDs, err := r.loadAccountGroupIDs(ctx, id)
if err != nil {
return err
}
// 使用事务保证账号与关联分组的删除原子性
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
@@ -387,12 +368,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
}
if tx != nil {
if err := tx.Commit(); err != nil {
return err
}
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
return tx.Commit()
}
return nil
}
@@ -479,18 +455,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
Where(dbaccount.IDEQ(id)).
SetLastUsedAt(now).
Save(ctx)
if err != nil {
return err
}
payload := map[string]any{
"last_used": map[string]int64{
strconv.FormatInt(id, 10): now.Unix(),
},
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
}
return nil
return err
}
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
@@ -514,18 +479,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
args = append(args, pq.Array(ids))
_, err := r.sql.ExecContext(ctx, caseSQL, args...)
if err != nil {
return err
}
lastUsedPayload := make(map[string]int64, len(updates))
for id, ts := range updates {
lastUsedPayload[strconv.FormatInt(id, 10)] = ts.Unix()
}
payload := map[string]any{"last_used": lastUsedPayload}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
}
return nil
return err
}
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
@@ -534,21 +488,6 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
SetStatus(service.StatusError).
SetErrorMessage(errorMsg).
Save(ctx)
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetStatus(service.StatusActive).
SetErrorMessage("").
Save(ctx)
return err
}
@@ -558,14 +497,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
SetGroupID(groupID).
SetPriority(priority).
Save(ctx)
if err != nil {
return err
}
payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
}
return nil
return err
}
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
@@ -575,14 +507,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
dbaccountgroup.GroupIDEQ(groupID),
).
Exec(ctx)
if err != nil {
return err
}
payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
}
return nil
return err
}
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
@@ -603,10 +528,6 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s
}
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
existingGroupIDs, err := r.loadAccountGroupIDs(ctx, accountID)
if err != nil {
return err
}
// 使用事务保证删除旧绑定与创建新绑定的原子性
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
@@ -647,13 +568,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
}
if tx != nil {
if err := tx.Commit(); err != nil {
return err
}
}
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
return tx.Commit()
}
return nil
}
@@ -757,13 +672,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
SetRateLimitedAt(now).
SetRateLimitResetAt(resetAt).
Save(ctx)
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
}
return nil
return err
}
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
@@ -797,49 +706,6 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
if scope == "" {
return nil
}
now := time.Now().UTC()
payload := map[string]string{
"rate_limited_at": now.Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
raw, err := json.Marshal(payload)
if err != nil {
return err
}
path := "{model_rate_limits," + scope + "}"
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
path,
raw,
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
}
return nil
}
@@ -848,13 +714,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
Where(dbaccount.IDEQ(id)).
SetOverloadUntil(until).
Save(ctx)
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
}
return nil
return err
}
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
@@ -867,13 +727,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
AND deleted_at IS NULL
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
`, until, reason, id)
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
}
return nil
return err
}
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
@@ -885,13 +739,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
WHERE id = $1
AND deleted_at IS NULL
`, id)
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
}
return nil
return err
}
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
@@ -901,13 +749,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
ClearRateLimitResetAt().
ClearOverloadUntil().
Save(ctx)
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
}
return nil
return err
}
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
@@ -928,33 +770,6 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'model_rate_limits', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL",
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
}
return nil
}
@@ -969,16 +784,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
builder.SetSessionWindowEnd(*end)
}
_, err := builder.Save(ctx)
if err != nil {
return err
}
// 触发调度器缓存更新(仅当窗口时间有变化时)
if start != nil || end != nil {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
}
}
return nil
return err
}
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
@@ -986,13 +792,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
Where(dbaccount.IDEQ(id)).
SetSchedulable(schedulable).
Save(ctx)
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
}
return nil
return err
}
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
@@ -1013,11 +813,6 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
if err != nil {
return 0, err
}
if rows > 0 {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
}
}
return rows, nil
}
@@ -1049,9 +844,6 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
}
return nil
}
@@ -1089,11 +881,6 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.Priority)
idx++
}
if updates.RateMultiplier != nil {
setClauses = append(setClauses, "rate_multiplier = $"+itoa(idx))
args = append(args, *updates.RateMultiplier)
idx++
}
if updates.Status != nil {
setClauses = append(setClauses, "status = $"+itoa(idx))
args = append(args, *updates.Status)
@@ -1141,12 +928,6 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if err != nil {
return 0, err
}
if rows > 0 {
payload := map[string]any{"account_ids": ids}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
}
}
return rows, nil
}
@@ -1389,61 +1170,11 @@ func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
}
func (r *accountRepository) loadAccountGroupIDs(ctx context.Context, accountID int64) ([]int64, error) {
entries, err := r.client.AccountGroup.
Query().
Where(dbaccountgroup.AccountIDEQ(accountID)).
All(ctx)
if err != nil {
return nil, err
}
ids := make([]int64, 0, len(entries))
for _, entry := range entries {
ids = append(ids, entry.GroupID)
}
return ids, nil
}
func mergeGroupIDs(a []int64, b []int64) []int64 {
seen := make(map[int64]struct{}, len(a)+len(b))
out := make([]int64, 0, len(a)+len(b))
for _, id := range a {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
for _, id := range b {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
return out
}
func buildSchedulerGroupPayload(groupIDs []int64) map[string]any {
if len(groupIDs) == 0 {
return nil
}
return map[string]any{"group_ids": groupIDs}
}
func accountEntityToService(m *dbent.Account) *service.Account {
if m == nil {
return nil
}
rateMultiplier := m.RateMultiplier
return &service.Account{
ID: m.ID,
Name: m.Name,
@@ -1455,7 +1186,6 @@ func accountEntityToService(m *dbent.Account) *service.Account {
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
RateMultiplier: &rateMultiplier,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,

View File

@@ -2,10 +2,8 @@ package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -13,10 +11,8 @@ import (
)
const (
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour
apiKeyAuthCachePrefix = "apikey:auth:"
authCacheInvalidateChannel = "auth:cache:invalidate"
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour
)
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
@@ -24,10 +20,6 @@ func apiKeyRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
}
func apiKeyAuthCacheKey(key string) string {
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
}
type apiKeyCache struct {
rdb *redis.Client
}
@@ -66,72 +58,3 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return c.rdb.Expire(ctx, apiKey, ttl).Err()
}
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
if err != nil {
return nil, err
}
var entry service.APIKeyAuthCacheEntry
if err := json.Unmarshal(val, &entry); err != nil {
return nil, err
}
return &entry, nil
}
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
if entry == nil {
return nil
}
payload, err := json.Marshal(entry)
if err != nil {
return err
}
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
}
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
}
// PublishAuthCacheInvalidation publishes a cache invalidation message to all instances
func (c *apiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
return c.rdb.Publish(ctx, authCacheInvalidateChannel, cacheKey).Err()
}
// SubscribeAuthCacheInvalidation subscribes to cache invalidation messages
func (c *apiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
pubsub := c.rdb.Subscribe(ctx, authCacheInvalidateChannel)
// Verify subscription is working
_, err := pubsub.Receive(ctx)
if err != nil {
_ = pubsub.Close()
return fmt.Errorf("subscribe to auth cache invalidation: %w", err)
}
go func() {
defer func() {
if err := pubsub.Close(); err != nil {
log.Printf("Warning: failed to close auth cache invalidation pubsub: %v", err)
}
}()
ch := pubsub.Channel()
for {
select {
case <-ctx.Done():
return
case msg, ok := <-ch:
if !ok {
return
}
if msg != nil {
handler(msg.Payload)
}
}
}
}()
return nil
}

View File

@@ -6,9 +6,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -66,23 +64,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
return apiKeyEntityToService(m), nil
}
// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者用户ID。
// GetOwnerID 根据 API Key ID 获取其所有者(用户)ID。
// 相比 GetByID此方法性能更优因为
// - 使用 Select() 只查询必要字段,减少数据传输量
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
// - 不加载完整的 API Key 实体及其关联数据User、Group 等)
// - 适用于删除等只需 key 与用户 ID 的场景
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldKey, apikey.FieldUserID).
Select(apikey.FieldUserID).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return "", 0, service.ErrAPIKeyNotFound
return 0, service.ErrAPIKeyNotFound
}
return "", 0, err
return 0, err
}
return m.Key, m.UserID, nil
return m.UserID, nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
@@ -100,56 +98,6 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
Select(
apikey.FieldID,
apikey.FieldUserID,
apikey.FieldGroupID,
apikey.FieldStatus,
apikey.FieldIPWhitelist,
apikey.FieldIPBlacklist,
).
WithUser(func(q *dbent.UserQuery) {
q.Select(
user.FieldID,
user.FieldStatus,
user.FieldRole,
user.FieldBalance,
user.FieldConcurrency,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
q.Select(
group.FieldID,
group.FieldName,
group.FieldPlatform,
group.FieldStatus,
group.FieldSubscriptionType,
group.FieldRateMultiplier,
group.FieldDailyLimitUsd,
group.FieldWeeklyLimitUsd,
group.FieldMonthlyLimitUsd,
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
group.FieldModelRoutingEnabled,
group.FieldModelRouting,
)
}).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID若在两步之间发生软删除
@@ -335,28 +283,6 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
return int64(count), err
}
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
keys, err := r.activeQuery().
Where(apikey.UserIDEQ(userID)).
Select(apikey.FieldKey).
Strings(ctx)
if err != nil {
return nil, err
}
return keys, nil
}
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
keys, err := r.activeQuery().
Where(apikey.GroupIDEQ(groupID)).
Select(apikey.FieldKey).
Strings(ctx)
if err != nil {
return nil, err
}
return keys, nil
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil {
return nil
@@ -424,8 +350,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}

View File

@@ -93,7 +93,7 @@ var (
return redis.call('ZCARD', key)
`)
// incrementWaitScript - refreshes TTL on each increment to keep queue depth accurate
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
@@ -111,13 +111,15 @@ var (
local newVal = redis.call('INCR', KEYS[1])
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
redis.call('EXPIRE', KEYS[1], ARGV[2])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
// incrementAccountWaitScript - account-level wait queue count (refresh TTL on each increment)
// incrementAccountWaitScript - account-level wait queue count
incrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
@@ -132,8 +134,10 @@ var (
local newVal = redis.call('INCR', KEYS[1])
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
redis.call('EXPIRE', KEYS[1], ARGV[2])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)

View File

@@ -1,461 +0,0 @@
package repository
import (
"context"
"database/sql"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type dashboardAggregationRepository struct {
sql sqlExecutor
}
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
if sqlDB == nil {
return nil
}
if !isPostgresDriver(sqlDB) {
log.Printf("[DashboardAggregation] 检测到非 PostgreSQL 驱动,已自动禁用预聚合")
return nil
}
return newDashboardAggregationRepositoryWithSQL(sqlDB)
}
func newDashboardAggregationRepositoryWithSQL(sqlq sqlExecutor) *dashboardAggregationRepository {
return &dashboardAggregationRepository{sql: sqlq}
}
func isPostgresDriver(db *sql.DB) bool {
if db == nil {
return false
}
_, ok := db.Driver().(*pq.Driver)
return ok
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
if !endLocal.After(startLocal) {
return nil
}
hourStart := startLocal.Truncate(time.Hour)
hourEnd := endLocal.Truncate(time.Hour)
if endLocal.After(hourEnd) {
hourEnd = hourEnd.Add(time.Hour)
}
dayStart := truncateToDay(startLocal)
dayEnd := truncateToDay(endLocal)
if endLocal.After(dayEnd) {
dayEnd = dayEnd.Add(24 * time.Hour)
}
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
return err
}
return nil
}
func (r *dashboardAggregationRepository) RecomputeRange(ctx context.Context, start, end time.Time) error {
if r == nil || r.sql == nil {
return nil
}
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
if !endLocal.After(startLocal) {
return nil
}
hourStart := startLocal.Truncate(time.Hour)
hourEnd := endLocal.Truncate(time.Hour)
if endLocal.After(hourEnd) {
hourEnd = hourEnd.Add(time.Hour)
}
dayStart := truncateToDay(startLocal)
dayEnd := truncateToDay(endLocal)
if endLocal.After(dayEnd) {
dayEnd = dayEnd.Add(24 * time.Hour)
}
// 尽量使用事务保证范围内的一致性(允许在非 *sql.DB 的情况下退化为非事务执行)。
if db, ok := r.sql.(*sql.DB); ok {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
if err := txRepo.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
return r.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
}
func (r *dashboardAggregationRepository) recomputeRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
// 先清空范围内桶,再重建(避免仅增量插入导致活跃用户等指标无法回退)。
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
return err
}
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
return err
}
return nil
}
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
var ts time.Time
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
if err := scanSingleRow(ctx, r.sql, query, nil, &ts); err != nil {
if err == sql.ErrNoRows {
return time.Unix(0, 0).UTC(), nil
}
return time.Time{}, err
}
return ts.UTC(), nil
}
func (r *dashboardAggregationRepository) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
query := `
INSERT INTO usage_dashboard_aggregation_watermark (id, last_aggregated_at, updated_at)
VALUES (1, $1, NOW())
ON CONFLICT (id)
DO UPDATE SET last_aggregated_at = EXCLUDED.last_aggregated_at, updated_at = EXCLUDED.updated_at
`
_, err := r.sql.ExecContext(ctx, query, aggregatedAt.UTC())
return err
}
func (r *dashboardAggregationRepository) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
hourlyCutoffUTC := hourlyCutoff.UTC()
dailyCutoffUTC := dailyCutoff.UTC()
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
return err
}
return nil
}
func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
if err != nil {
return err
}
if isPartitioned {
return r.dropUsageLogsPartitions(ctx, cutoff)
}
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
return err
}
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
if err != nil || !isPartitioned {
return err
}
monthStart := truncateToMonthUTC(now)
prevMonth := monthStart.AddDate(0, -1, 0)
nextMonth := monthStart.AddDate(0, 1, 0)
for _, m := range []time.Time{prevMonth, monthStart, nextMonth} {
if err := r.createUsageLogsPartition(ctx, m); err != nil {
return err
}
}
return nil
}
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
SELECT DISTINCT
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
user_id
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
ON CONFLICT DO NOTHING
`
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
SELECT DISTINCT
(bucket_start AT TIME ZONE $3)::date AS bucket_date,
user_id
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
ON CONFLICT DO NOTHING
`
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
WITH hourly AS (
SELECT
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
COUNT(*) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
COALESCE(SUM(total_cost), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) AS actual_cost,
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
GROUP BY 1
),
user_counts AS (
SELECT bucket_start, COUNT(*) AS active_users
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY bucket_start
)
INSERT INTO usage_dashboard_hourly (
bucket_start,
total_requests,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
total_cost,
actual_cost,
total_duration_ms,
active_users,
computed_at
)
SELECT
hourly.bucket_start,
hourly.total_requests,
hourly.input_tokens,
hourly.output_tokens,
hourly.cache_creation_tokens,
hourly.cache_read_tokens,
hourly.total_cost,
hourly.actual_cost,
hourly.total_duration_ms,
COALESCE(user_counts.active_users, 0) AS active_users,
NOW()
FROM hourly
LEFT JOIN user_counts ON user_counts.bucket_start = hourly.bucket_start
ON CONFLICT (bucket_start)
DO UPDATE SET
total_requests = EXCLUDED.total_requests,
input_tokens = EXCLUDED.input_tokens,
output_tokens = EXCLUDED.output_tokens,
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
cache_read_tokens = EXCLUDED.cache_read_tokens,
total_cost = EXCLUDED.total_cost,
actual_cost = EXCLUDED.actual_cost,
total_duration_ms = EXCLUDED.total_duration_ms,
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
WITH daily AS (
SELECT
(bucket_start AT TIME ZONE $5)::date AS bucket_date,
COALESCE(SUM(total_requests), 0) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
COALESCE(SUM(total_cost), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) AS actual_cost,
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
FROM usage_dashboard_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY (bucket_start AT TIME ZONE $5)::date
),
user_counts AS (
SELECT bucket_date, COUNT(*) AS active_users
FROM usage_dashboard_daily_users
WHERE bucket_date >= $3::date AND bucket_date < $4::date
GROUP BY bucket_date
)
INSERT INTO usage_dashboard_daily (
bucket_date,
total_requests,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
total_cost,
actual_cost,
total_duration_ms,
active_users,
computed_at
)
SELECT
daily.bucket_date,
daily.total_requests,
daily.input_tokens,
daily.output_tokens,
daily.cache_creation_tokens,
daily.cache_read_tokens,
daily.total_cost,
daily.actual_cost,
daily.total_duration_ms,
COALESCE(user_counts.active_users, 0) AS active_users,
NOW()
FROM daily
LEFT JOIN user_counts ON user_counts.bucket_date = daily.bucket_date
ON CONFLICT (bucket_date)
DO UPDATE SET
total_requests = EXCLUDED.total_requests,
input_tokens = EXCLUDED.input_tokens,
output_tokens = EXCLUDED.output_tokens,
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
cache_read_tokens = EXCLUDED.cache_read_tokens,
total_cost = EXCLUDED.total_cost,
actual_cost = EXCLUDED.actual_cost,
total_duration_ms = EXCLUDED.total_duration_ms,
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) isUsageLogsPartitioned(ctx context.Context) (bool, error) {
query := `
SELECT EXISTS(
SELECT 1
FROM pg_partitioned_table pt
JOIN pg_class c ON c.oid = pt.partrelid
WHERE c.relname = 'usage_logs'
)
`
var partitioned bool
if err := scanSingleRow(ctx, r.sql, query, nil, &partitioned); err != nil {
return false, err
}
return partitioned, nil
}
func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Context, cutoff time.Time) error {
rows, err := r.sql.QueryContext(ctx, `
SELECT c.relname
FROM pg_inherits
JOIN pg_class c ON c.oid = pg_inherits.inhrelid
JOIN pg_class p ON p.oid = pg_inherits.inhparent
WHERE p.relname = 'usage_logs'
`)
if err != nil {
return err
}
defer func() {
_ = rows.Close()
}()
cutoffMonth := truncateToMonthUTC(cutoff)
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return err
}
if !strings.HasPrefix(name, "usage_logs_") {
continue
}
suffix := strings.TrimPrefix(name, "usage_logs_")
month, err := time.Parse("200601", suffix)
if err != nil {
continue
}
month = month.UTC()
if month.Before(cutoffMonth) {
if _, err := r.sql.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(name))); err != nil {
return err
}
}
}
return rows.Err()
}
func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Context, month time.Time) error {
monthStart := truncateToMonthUTC(month)
nextMonth := monthStart.AddDate(0, 1, 0)
name := fmt.Sprintf("usage_logs_%s", monthStart.Format("200601"))
query := fmt.Sprintf(
"CREATE TABLE IF NOT EXISTS %s PARTITION OF usage_logs FOR VALUES FROM (%s) TO (%s)",
pq.QuoteIdentifier(name),
pq.QuoteLiteral(monthStart.Format("2006-01-02")),
pq.QuoteLiteral(nextMonth.Format("2006-01-02")),
)
_, err := r.sql.ExecContext(ctx, query)
return err
}
func truncateToDay(t time.Time) time.Time {
return timezone.StartOfDay(t)
}
func truncateToMonthUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
}

View File

@@ -1,58 +0,0 @@
package repository
import (
"context"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const dashboardStatsCacheKey = "dashboard:stats:v1"
type dashboardCache struct {
rdb *redis.Client
keyPrefix string
}
func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardStatsCache {
prefix := "sub2api:"
if cfg != nil {
prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
}
if prefix != "" && !strings.HasSuffix(prefix, ":") {
prefix += ":"
}
return &dashboardCache{
rdb: rdb,
keyPrefix: prefix,
}
}
func (c *dashboardCache) GetDashboardStats(ctx context.Context) (string, error) {
val, err := c.rdb.Get(ctx, c.buildKey()).Result()
if err != nil {
if err == redis.Nil {
return "", service.ErrDashboardStatsCacheMiss
}
return "", err
}
return val, nil
}
func (c *dashboardCache) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
return c.rdb.Set(ctx, c.buildKey(), data, ttl).Err()
}
func (c *dashboardCache) buildKey() string {
if c.keyPrefix == "" {
return dashboardStatsCacheKey
}
return c.keyPrefix + dashboardStatsCacheKey
}
func (c *dashboardCache) DeleteDashboardStats(ctx context.Context) error {
return c.rdb.Del(ctx, c.buildKey()).Err()
}

View File

@@ -1,28 +0,0 @@
package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestNewDashboardCacheKeyPrefix(t *testing.T) {
cache := NewDashboardCache(nil, &config.Config{
Dashboard: config.DashboardCacheConfig{
KeyPrefix: "prod",
},
})
impl, ok := cache.(*dashboardCache)
require.True(t, ok)
require.Equal(t, "prod:", impl.keyPrefix)
cache = NewDashboardCache(nil, &config.Config{
Dashboard: config.DashboardCacheConfig{
KeyPrefix: "staging:",
},
})
impl, ok = cache.(*dashboardCache)
require.True(t, ok)
require.Equal(t, "staging:", impl.keyPrefix)
}

Some files were not shown because too many files have changed in this diff Show More