mirror of https://gitee.com/wanwujie/sub2api (synced 2026-05-02 12:20:45 +08:00)

merge: sync upstream changes
.gitignore (vendored): 3 lines changed
@@ -83,6 +83,8 @@ temp/
 *.log
 *.bak
 .cache/
+.dev/
+.serena/
 
 # ===================
 # 构建产物
@@ -127,3 +129,4 @@ deploy/docker-compose.override.yml
 .gocache/
 vite.config.js
 docs/*
+.serena/
PR_DESCRIPTION.md (new file): 164 lines added
@@ -0,0 +1,164 @@

## Overview

Comprehensively enhance the ops monitoring system's (Ops) error-log management and alert-silencing features, and improve the code quality and user experience of the frontend UI components. This update refactors the core service layer and data-access layer to improve maintainability and operational efficiency.

## Main Changes
### 1. Error-log query optimization

**Features:**

- Add a GetErrorLogByID endpoint for looking up a single error record by ID
- Improve error-log filtering with multi-dimensional criteria (platform, phase, source, owner, etc.)
- Simplify query-parameter handling and code structure
- Strengthen error classification and normalization
- Track error resolution state (`resolved` field)

**Implementation:**

- `ops_handler.go` - new single-error-log query endpoint
- `ops_repo.go` - improved query building and filter conditions
- `ops_models.go` - extended error-log data model
- Frontend API client updated to match
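The multi-dimensional filtering described above can be sketched as conditional WHERE-clause assembly. This is an illustrative sketch only: the `ErrorLogFilter` type, its field set, and `BuildWhere` are hypothetical names, not the actual code in `ops_repo.go`.

```go
package main

import (
	"fmt"
	"strings"
)

// ErrorLogFilter mirrors the filter dimensions listed above; the exact
// field set in ops_repo.go may differ - this is a sketch.
type ErrorLogFilter struct {
	Platform string
	Phase    string
	Source   string
	OwnerID  int64
	Resolved *bool // nil means "don't filter on resolution state"
}

// BuildWhere assembles a WHERE clause with positional args, adding a
// condition only for the dimensions that were actually supplied.
func BuildWhere(f ErrorLogFilter) (string, []any) {
	var conds []string
	var args []any
	add := func(expr string, v any) {
		args = append(args, v)
		conds = append(conds, fmt.Sprintf(expr, len(args)))
	}
	if f.Platform != "" {
		add("platform = $%d", f.Platform)
	}
	if f.Phase != "" {
		add("phase = $%d", f.Phase)
	}
	if f.Source != "" {
		add("source = $%d", f.Source)
	}
	if f.OwnerID > 0 {
		add("owner_id = $%d", f.OwnerID)
	}
	if f.Resolved != nil {
		add("resolved = $%d", *f.Resolved)
	}
	if len(conds) == 0 {
		return "", nil
	}
	return "WHERE " + strings.Join(conds, " AND "), args
}

func main() {
	resolved := false
	where, args := BuildWhere(ErrorLogFilter{Platform: "claude", Phase: "upstream", Resolved: &resolved})
	fmt.Println(where, len(args))
}
```

Building only the supplied conditions keeps one query path for both the broad list view and narrow drill-downs, which is presumably why the PR describes the filter logic and parameter handling as "simplified".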
### 2. Alert silencing

**Features:**

- Silence alerts by rule, platform, group, region, and other dimensions
- Configurable silence duration and reason
- Silence records are auditable, with creator and creation time recorded
- Automatic expiry, so silences can never become permanent

**Implementation:**

- `037_ops_alert_silences.sql` - new alert-silence table
- `ops_alerts.go` - alert-silencing logic
- `ops_alerts_handler.go` - alert-silencing API endpoints
- `OpsAlertEventsCard.vue` - frontend silencing UI

**Table structure:**

| Column | Type | Description |
|--------|------|-------------|
| rule_id | BIGINT | Alert rule ID |
| platform | VARCHAR(64) | Platform identifier |
| group_id | BIGINT | Group ID (optional) |
| region | VARCHAR(64) | Region (optional) |
| until | TIMESTAMPTZ | Silence deadline |
| reason | TEXT | Silence reason |
| created_by | BIGINT | Creator user ID |
### 3. Error-classification standardization

**Features:**

- Unified error-phase taxonomy (request|auth|routing|upstream|network|internal)
- Normalized error ownership (client|provider|platform)
- Standardized error sources (client_request|upstream_http|gateway)
- Automatic migration of historical data onto the new taxonomy

**Implementation:**

- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - classification-standardization migration
- Automatically maps legacy classifications onto the new standard
- Automatically resolves recovered upstream errors (client status code < 400)
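The normalization and auto-resolve rules above can be sketched as follows. The standard phase values come from the taxonomy listed above; the legacy alias names on the left of the map are hypothetical examples (the real mapping lives in the 038 migration), and `normalizePhase`/`autoResolved` are illustrative names.

```go
package main

import "fmt"

// Standard phase values from the taxonomy above.
var validPhases = map[string]bool{
	"request": true, "auth": true, "routing": true,
	"upstream": true, "network": true, "internal": true,
}

// normalizePhase maps legacy phase labels onto the standardized set.
// The alias keys here are hypothetical examples of legacy values.
func normalizePhase(legacy string) string {
	aliases := map[string]string{
		"timeout":  "network",
		"schedule": "routing",
		"provider": "upstream",
	}
	if v, ok := aliases[legacy]; ok {
		return v
	}
	if validPhases[legacy] {
		return legacy
	}
	return "internal" // unrecognized history falls back to internal
}

// autoResolved mirrors the migration rule above: an upstream error whose
// client-facing status code ended up below 400 is marked resolved.
func autoResolved(phase string, clientStatus int) bool {
	return phase == "upstream" && clientStatus < 400
}

func main() {
	fmt.Println(normalizePhase("timeout"), autoResolved("upstream", 200))
}
```

The auto-resolve rule encodes the idea that an upstream failure which was retried or rerouted into a successful client response needs no operator attention.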
### 4. Gateway service integration

**Features:**

- Complete the Ops integration for each gateway service
- Unified error-logging interface
- Better upstream-error tracing

**Affected services:**

- `antigravity_gateway_service.go` - Antigravity gateway integration
- `gateway_service.go` - generic gateway integration
- `gemini_messages_compat_service.go` - Gemini compatibility-layer integration
- `openai_gateway_service.go` - OpenAI gateway integration
### 5. Frontend UI improvements

**Refactoring:**

- Greatly simplify the error-detail modal (from 828 lines down to 450)
- Clean up the error-log table component for readability
- Remove unused i18n translations to cut redundancy
- Unify component code style and formatting
- Rework the skeleton component to better match the actual dashboard layout

**Layout fixes:**

- Fix modal content overflow and scrolling issues
- Use flex layout so tables render correctly
- Improve dashboard header layout and interaction
- Better responsive behaviour
- Skeleton screens now adapt to fullscreen mode

**Interaction:**

- Improve the alert-events card's features and presentation
- Refine error-detail display logic
- Enhance the request-detail modal
- Polish the runtime-settings card
- Smoother loading animations
### 6. Internationalization

**Copy updates:**

- Add missing English translations for error logs
- Add Chinese and English copy for the alert-silencing feature
- Complete hint texts and error messages
- Standardize terminology across translations
## Changed Files

**Backend (26 files):**

- `backend/internal/handler/admin/ops_alerts_handler.go` - alert endpoint enhancements
- `backend/internal/handler/admin/ops_handler.go` - error-log endpoint optimization
- `backend/internal/handler/ops_error_logger.go` - error-logger enhancements
- `backend/internal/repository/ops_repo.go` - data-access layer refactor
- `backend/internal/repository/ops_repo_alerts.go` - alert data-access enhancements
- `backend/internal/service/ops_*.go` - core service-layer refactor (10 files)
- `backend/internal/service/*_gateway_service.go` - gateway integration (4 files)
- `backend/internal/server/routes/admin.go` - route configuration updates
- `backend/migrations/*.sql` - database migrations (2 files)
- Test updates (5 files)

**Frontend (13 files):**

- `frontend/src/views/admin/ops/OpsDashboard.vue` - dashboard page improvements
- `frontend/src/views/admin/ops/components/*.vue` - component refactor (10 files)
- `frontend/src/api/admin/ops.ts` - API client extensions
- `frontend/src/i18n/locales/*.ts` - i18n text (2 files)
## Code Statistics

- 44 files changed
- 3733 insertions
- 995 deletions
- Net +2738 lines
## Key Improvements

**Maintainability:**

- Core service layer refactored with clearer responsibilities
- Simpler frontend components with lower complexity
- Unified code style and naming conventions
- Redundant code and unused translations removed
- Standardized error-classification taxonomy

**Functionality:**

- Alert silencing reduces alert noise
- Optimized error-log queries improve ops efficiency
- Completed gateway integration unifies monitoring
- Error-resolution tracking simplifies issue management

**User experience:**

- Multiple UI layout bugs fixed
- Smoother interaction flows
- Better i18n coverage
- Improved responsive behaviour
- Better loading-state presentation
## Testing

- ✅ Error-log querying and filtering
- ✅ Alert-silence creation and auto-expiry
- ✅ Classification-standardization migration
- ✅ Gateway error logging
- ✅ Frontend component layout and interaction
- ✅ Skeleton fullscreen adaptation
- ✅ i18n text completeness
- ✅ API endpoint correctness
- ✅ Database migrations run successfully
@@ -67,7 +67,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	userHandler := handler.NewUserHandler(userService)
 	apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
 	usageLogRepository := repository.NewUsageLogRepository(client, db)
-	dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
 	usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
 	usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
 	redeemCodeRepository := repository.NewRedeemCodeRepository(client)
@@ -76,15 +75,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
 	redeemHandler := handler.NewRedeemHandler(redeemService)
 	subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
+	dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
 	dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
-	timingWheelService := service.ProvideTimingWheelService()
-	dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
 	dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
+	timingWheelService, err := service.ProvideTimingWheelService()
+	if err != nil {
+		return nil, err
+	}
+	dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
 	dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
 	accountRepository := repository.NewAccountRepository(client, db)
 	proxyRepository := repository.NewProxyRepository(client, db)
 	proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
-	adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator)
+	proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
+	adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
 	adminUserHandler := admin.NewUserHandler(adminService)
 	groupHandler := admin.NewGroupHandler(adminService)
 	claudeOAuthClient := repository.NewClaudeOAuthClient()
@@ -98,12 +102,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
 	tempUnschedCache := repository.NewTempUnschedCache(redisClient)
 	timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
-	rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService)
+	geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
+	compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
+	rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
 	claudeUsageFetcher := repository.NewClaudeUsageFetcher()
 	antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
 	usageCache := service.NewUsageCache()
 	accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
-	geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
 	geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
 	gatewayCache := repository.NewGatewayCache(redisClient)
 	antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
@@ -112,11 +117,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
 	concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
 	concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
-	schedulerCache := repository.NewSchedulerCache(redisClient)
-	schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
-	schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
 	crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
-	accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
+	sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
+	accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache)
 	oAuthHandler := admin.NewOAuthHandler(oAuthService)
 	openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
 	geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -125,6 +128,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	adminRedeemHandler := admin.NewRedeemHandler(adminService)
 	promoHandler := admin.NewPromoHandler(promoService)
 	opsRepository := repository.NewOpsRepository(db)
+	schedulerCache := repository.NewSchedulerCache(redisClient)
+	schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
+	schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
 	pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
 	pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
 	if err != nil {
@@ -134,8 +140,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	identityCache := repository.NewIdentityCache(redisClient)
 	identityService := service.NewIdentityService(identityCache)
 	deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
-	gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
-	openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
+	claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
+	gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
+	openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
+	openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
 	geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
 	opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
 	settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
@@ -166,7 +174,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
 	opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
 	opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
-	tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
+	tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
 	accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
 	v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
 	application := &Application{
@@ -43,6 +43,8 @@ type Account struct {
 	Concurrency int `json:"concurrency,omitempty"`
 	// Priority holds the value of the "priority" field.
 	Priority int `json:"priority,omitempty"`
+	// RateMultiplier holds the value of the "rate_multiplier" field.
+	RateMultiplier float64 `json:"rate_multiplier,omitempty"`
 	// Status holds the value of the "status" field.
 	Status string `json:"status,omitempty"`
 	// ErrorMessage holds the value of the "error_message" field.
@@ -135,6 +137,8 @@ func (*Account) scanValues(columns []string) ([]any, error) {
 			values[i] = new([]byte)
 		case account.FieldAutoPauseOnExpired, account.FieldSchedulable:
 			values[i] = new(sql.NullBool)
+		case account.FieldRateMultiplier:
+			values[i] = new(sql.NullFloat64)
 		case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
 			values[i] = new(sql.NullInt64)
 		case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
@@ -241,6 +245,12 @@ func (_m *Account) assignValues(columns []string, values []any) error {
 			} else if value.Valid {
 				_m.Priority = int(value.Int64)
 			}
+		case account.FieldRateMultiplier:
+			if value, ok := values[i].(*sql.NullFloat64); !ok {
+				return fmt.Errorf("unexpected type %T for field rate_multiplier", values[i])
+			} else if value.Valid {
+				_m.RateMultiplier = value.Float64
+			}
 		case account.FieldStatus:
 			if value, ok := values[i].(*sql.NullString); !ok {
 				return fmt.Errorf("unexpected type %T for field status", values[i])
@@ -420,6 +430,9 @@ func (_m *Account) String() string {
 	builder.WriteString("priority=")
 	builder.WriteString(fmt.Sprintf("%v", _m.Priority))
 	builder.WriteString(", ")
+	builder.WriteString("rate_multiplier=")
+	builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
+	builder.WriteString(", ")
 	builder.WriteString("status=")
 	builder.WriteString(_m.Status)
 	builder.WriteString(", ")
@@ -39,6 +39,8 @@ const (
 	FieldConcurrency = "concurrency"
 	// FieldPriority holds the string denoting the priority field in the database.
 	FieldPriority = "priority"
+	// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
+	FieldRateMultiplier = "rate_multiplier"
 	// FieldStatus holds the string denoting the status field in the database.
 	FieldStatus = "status"
 	// FieldErrorMessage holds the string denoting the error_message field in the database.
@@ -116,6 +118,7 @@ var Columns = []string{
 	FieldProxyID,
 	FieldConcurrency,
 	FieldPriority,
+	FieldRateMultiplier,
 	FieldStatus,
 	FieldErrorMessage,
 	FieldLastUsedAt,
@@ -174,6 +177,8 @@ var (
 	DefaultConcurrency int
 	// DefaultPriority holds the default value on creation for the "priority" field.
 	DefaultPriority int
+	// DefaultRateMultiplier holds the default value on creation for the "rate_multiplier" field.
+	DefaultRateMultiplier float64
 	// DefaultStatus holds the default value on creation for the "status" field.
 	DefaultStatus string
 	// StatusValidator is a validator for the "status" field. It is called by the builders before save.
@@ -244,6 +249,11 @@ func ByPriority(opts ...sql.OrderTermOption) OrderOption {
 	return sql.OrderByField(FieldPriority, opts...).ToFunc()
 }
 
+// ByRateMultiplier orders the results by the rate_multiplier field.
+func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
+	return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
+}
+
 // ByStatus orders the results by the status field.
 func ByStatus(opts ...sql.OrderTermOption) OrderOption {
 	return sql.OrderByField(FieldStatus, opts...).ToFunc()
@@ -105,6 +105,11 @@ func Priority(v int) predicate.Account {
 	return predicate.Account(sql.FieldEQ(FieldPriority, v))
 }
 
+// RateMultiplier applies equality check predicate on the "rate_multiplier" field. It's identical to RateMultiplierEQ.
+func RateMultiplier(v float64) predicate.Account {
+	return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
+}
+
 // Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
 func Status(v string) predicate.Account {
 	return predicate.Account(sql.FieldEQ(FieldStatus, v))
@@ -675,6 +680,46 @@ func PriorityLTE(v int) predicate.Account {
 	return predicate.Account(sql.FieldLTE(FieldPriority, v))
 }
 
+// RateMultiplierEQ applies the EQ predicate on the "rate_multiplier" field.
+func RateMultiplierEQ(v float64) predicate.Account {
+	return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
+}
+
+// RateMultiplierNEQ applies the NEQ predicate on the "rate_multiplier" field.
+func RateMultiplierNEQ(v float64) predicate.Account {
+	return predicate.Account(sql.FieldNEQ(FieldRateMultiplier, v))
+}
+
+// RateMultiplierIn applies the In predicate on the "rate_multiplier" field.
+func RateMultiplierIn(vs ...float64) predicate.Account {
+	return predicate.Account(sql.FieldIn(FieldRateMultiplier, vs...))
+}
+
+// RateMultiplierNotIn applies the NotIn predicate on the "rate_multiplier" field.
+func RateMultiplierNotIn(vs ...float64) predicate.Account {
+	return predicate.Account(sql.FieldNotIn(FieldRateMultiplier, vs...))
+}
+
+// RateMultiplierGT applies the GT predicate on the "rate_multiplier" field.
+func RateMultiplierGT(v float64) predicate.Account {
+	return predicate.Account(sql.FieldGT(FieldRateMultiplier, v))
+}
+
+// RateMultiplierGTE applies the GTE predicate on the "rate_multiplier" field.
+func RateMultiplierGTE(v float64) predicate.Account {
+	return predicate.Account(sql.FieldGTE(FieldRateMultiplier, v))
+}
+
+// RateMultiplierLT applies the LT predicate on the "rate_multiplier" field.
+func RateMultiplierLT(v float64) predicate.Account {
+	return predicate.Account(sql.FieldLT(FieldRateMultiplier, v))
+}
+
+// RateMultiplierLTE applies the LTE predicate on the "rate_multiplier" field.
+func RateMultiplierLTE(v float64) predicate.Account {
+	return predicate.Account(sql.FieldLTE(FieldRateMultiplier, v))
+}
+
 // StatusEQ applies the EQ predicate on the "status" field.
 func StatusEQ(v string) predicate.Account {
 	return predicate.Account(sql.FieldEQ(FieldStatus, v))
@@ -153,6 +153,20 @@ func (_c *AccountCreate) SetNillablePriority(v *int) *AccountCreate {
 	return _c
 }
 
+// SetRateMultiplier sets the "rate_multiplier" field.
+func (_c *AccountCreate) SetRateMultiplier(v float64) *AccountCreate {
+	_c.mutation.SetRateMultiplier(v)
+	return _c
+}
+
+// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
+func (_c *AccountCreate) SetNillableRateMultiplier(v *float64) *AccountCreate {
+	if v != nil {
+		_c.SetRateMultiplier(*v)
+	}
+	return _c
+}
+
 // SetStatus sets the "status" field.
 func (_c *AccountCreate) SetStatus(v string) *AccountCreate {
 	_c.mutation.SetStatus(v)
@@ -429,6 +443,10 @@ func (_c *AccountCreate) defaults() error {
|
|||||||
v := account.DefaultPriority
|
v := account.DefaultPriority
|
||||||
_c.mutation.SetPriority(v)
|
_c.mutation.SetPriority(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.RateMultiplier(); !ok {
|
||||||
|
v := account.DefaultRateMultiplier
|
||||||
|
_c.mutation.SetRateMultiplier(v)
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.Status(); !ok {
|
if _, ok := _c.mutation.Status(); !ok {
|
||||||
v := account.DefaultStatus
|
v := account.DefaultStatus
|
||||||
_c.mutation.SetStatus(v)
|
_c.mutation.SetStatus(v)
|
||||||
@@ -488,6 +506,9 @@ func (_c *AccountCreate) check() error {
|
|||||||
if _, ok := _c.mutation.Priority(); !ok {
|
if _, ok := _c.mutation.Priority(); !ok {
|
||||||
return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "Account.priority"`)}
|
return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "Account.priority"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.RateMultiplier(); !ok {
|
||||||
|
return &ValidationError{Name: "rate_multiplier", err: errors.New(`ent: missing required field "Account.rate_multiplier"`)}
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.Status(); !ok {
|
if _, ok := _c.mutation.Status(); !ok {
|
||||||
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Account.status"`)}
|
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Account.status"`)}
|
||||||
}
|
}
|
||||||
@@ -578,6 +599,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||||
_node.Priority = value
|
_node.Priority = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.RateMultiplier(); ok {
|
||||||
|
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
|
_node.RateMultiplier = value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.Status(); ok {
|
if value, ok := _c.mutation.Status(); ok {
|
||||||
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
||||||
_node.Status = value
|
_node.Status = value
|
||||||
@@ -893,6 +918,24 @@ func (u *AccountUpsert) AddPriority(v int) *AccountUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (u *AccountUpsert) SetRateMultiplier(v float64) *AccountUpsert {
|
||||||
|
u.Set(account.FieldRateMultiplier, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
|
||||||
|
func (u *AccountUpsert) UpdateRateMultiplier() *AccountUpsert {
|
||||||
|
u.SetExcluded(account.FieldRateMultiplier)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRateMultiplier adds v to the "rate_multiplier" field.
|
||||||
|
func (u *AccountUpsert) AddRateMultiplier(v float64) *AccountUpsert {
|
||||||
|
u.Add(account.FieldRateMultiplier, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (u *AccountUpsert) SetStatus(v string) *AccountUpsert {
|
func (u *AccountUpsert) SetStatus(v string) *AccountUpsert {
|
||||||
u.Set(account.FieldStatus, v)
|
u.Set(account.FieldStatus, v)
|
||||||
@@ -1325,6 +1368,27 @@ func (u *AccountUpsertOne) UpdatePriority() *AccountUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (u *AccountUpsertOne) SetRateMultiplier(v float64) *AccountUpsertOne {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.SetRateMultiplier(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRateMultiplier adds v to the "rate_multiplier" field.
|
||||||
|
func (u *AccountUpsertOne) AddRateMultiplier(v float64) *AccountUpsertOne {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.AddRateMultiplier(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
|
||||||
|
func (u *AccountUpsertOne) UpdateRateMultiplier() *AccountUpsertOne {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.UpdateRateMultiplier()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (u *AccountUpsertOne) SetStatus(v string) *AccountUpsertOne {
|
func (u *AccountUpsertOne) SetStatus(v string) *AccountUpsertOne {
|
||||||
return u.Update(func(s *AccountUpsert) {
|
return u.Update(func(s *AccountUpsert) {
|
||||||
@@ -1956,6 +2020,27 @@ func (u *AccountUpsertBulk) UpdatePriority() *AccountUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (u *AccountUpsertBulk) SetRateMultiplier(v float64) *AccountUpsertBulk {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.SetRateMultiplier(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRateMultiplier adds v to the "rate_multiplier" field.
|
||||||
|
func (u *AccountUpsertBulk) AddRateMultiplier(v float64) *AccountUpsertBulk {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.AddRateMultiplier(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
|
||||||
|
func (u *AccountUpsertBulk) UpdateRateMultiplier() *AccountUpsertBulk {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.UpdateRateMultiplier()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (u *AccountUpsertBulk) SetStatus(v string) *AccountUpsertBulk {
|
func (u *AccountUpsertBulk) SetStatus(v string) *AccountUpsertBulk {
|
||||||
return u.Update(func(s *AccountUpsert) {
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
|||||||
@@ -193,6 +193,27 @@ func (_u *AccountUpdate) AddPriority(v int) *AccountUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (_u *AccountUpdate) SetRateMultiplier(v float64) *AccountUpdate {
|
||||||
|
_u.mutation.ResetRateMultiplier()
|
||||||
|
_u.mutation.SetRateMultiplier(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
|
||||||
|
func (_u *AccountUpdate) SetNillableRateMultiplier(v *float64) *AccountUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetRateMultiplier(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRateMultiplier adds value to the "rate_multiplier" field.
|
||||||
|
func (_u *AccountUpdate) AddRateMultiplier(v float64) *AccountUpdate {
|
||||||
|
_u.mutation.AddRateMultiplier(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (_u *AccountUpdate) SetStatus(v string) *AccountUpdate {
|
func (_u *AccountUpdate) SetStatus(v string) *AccountUpdate {
|
||||||
_u.mutation.SetStatus(v)
|
_u.mutation.SetStatus(v)
|
||||||
@@ -629,6 +650,12 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.AddedPriority(); ok {
|
if value, ok := _u.mutation.AddedPriority(); ok {
|
||||||
_spec.AddField(account.FieldPriority, field.TypeInt, value)
|
_spec.AddField(account.FieldPriority, field.TypeInt, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.RateMultiplier(); ok {
|
||||||
|
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
|
||||||
|
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.Status(); ok {
|
if value, ok := _u.mutation.Status(); ok {
|
||||||
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
||||||
}
|
}
|
||||||
@@ -1005,6 +1032,27 @@ func (_u *AccountUpdateOne) AddPriority(v int) *AccountUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (_u *AccountUpdateOne) SetRateMultiplier(v float64) *AccountUpdateOne {
|
||||||
|
_u.mutation.ResetRateMultiplier()
|
||||||
|
_u.mutation.SetRateMultiplier(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
|
||||||
|
func (_u *AccountUpdateOne) SetNillableRateMultiplier(v *float64) *AccountUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetRateMultiplier(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRateMultiplier adds value to the "rate_multiplier" field.
|
||||||
|
func (_u *AccountUpdateOne) AddRateMultiplier(v float64) *AccountUpdateOne {
|
||||||
|
_u.mutation.AddRateMultiplier(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (_u *AccountUpdateOne) SetStatus(v string) *AccountUpdateOne {
|
func (_u *AccountUpdateOne) SetStatus(v string) *AccountUpdateOne {
|
||||||
_u.mutation.SetStatus(v)
|
_u.mutation.SetStatus(v)
|
||||||
@@ -1471,6 +1519,12 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
|
|||||||
if value, ok := _u.mutation.AddedPriority(); ok {
|
if value, ok := _u.mutation.AddedPriority(); ok {
|
||||||
_spec.AddField(account.FieldPriority, field.TypeInt, value)
|
_spec.AddField(account.FieldPriority, field.TypeInt, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.RateMultiplier(); ok {
|
||||||
|
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
|
||||||
|
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.Status(); ok {
|
if value, ok := _u.mutation.Status(); ok {
|
||||||
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
package ent
|
package ent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -55,6 +56,10 @@ type Group struct {
|
|||||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||||
// 非 Claude Code 请求降级使用的分组 ID
|
// 非 Claude Code 请求降级使用的分组 ID
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
|
// 模型路由配置:模型模式 -> 优先账号ID列表
|
||||||
|
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||||
|
// 是否启用模型路由配置
|
||||||
|
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||||
Edges GroupEdges `json:"edges"`
|
Edges GroupEdges `json:"edges"`
|
||||||
@@ -161,7 +166,9 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
|||||||
values := make([]any, len(columns))
|
values := make([]any, len(columns))
|
||||||
for i := range columns {
|
for i := range columns {
|
||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly:
|
case group.FieldModelRouting:
|
||||||
|
values[i] = new([]byte)
|
||||||
|
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
@@ -315,6 +322,20 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
|||||||
_m.FallbackGroupID = new(int64)
|
_m.FallbackGroupID = new(int64)
|
||||||
*_m.FallbackGroupID = value.Int64
|
*_m.FallbackGroupID = value.Int64
|
||||||
}
|
}
|
||||||
|
case group.FieldModelRouting:
|
||||||
|
if value, ok := values[i].(*[]byte); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field model_routing", values[i])
|
||||||
|
} else if value != nil && len(*value) > 0 {
|
||||||
|
if err := json.Unmarshal(*value, &_m.ModelRouting); err != nil {
|
||||||
|
return fmt.Errorf("unmarshal field model_routing: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case group.FieldModelRoutingEnabled:
|
||||||
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field model_routing_enabled", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.ModelRoutingEnabled = value.Bool
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@@ -465,6 +486,12 @@ func (_m *Group) String() string {
|
|||||||
builder.WriteString("fallback_group_id=")
|
builder.WriteString("fallback_group_id=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
}
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("model_routing=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("model_routing_enabled=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled))
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,6 +53,10 @@ const (
|
|||||||
FieldClaudeCodeOnly = "claude_code_only"
|
FieldClaudeCodeOnly = "claude_code_only"
|
||||||
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
||||||
FieldFallbackGroupID = "fallback_group_id"
|
FieldFallbackGroupID = "fallback_group_id"
|
||||||
|
// FieldModelRouting holds the string denoting the model_routing field in the database.
|
||||||
|
FieldModelRouting = "model_routing"
|
||||||
|
// FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database.
|
||||||
|
FieldModelRoutingEnabled = "model_routing_enabled"
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
EdgeAPIKeys = "api_keys"
|
EdgeAPIKeys = "api_keys"
|
||||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||||
@@ -147,6 +151,8 @@ var Columns = []string{
|
|||||||
FieldImagePrice4k,
|
FieldImagePrice4k,
|
||||||
FieldClaudeCodeOnly,
|
FieldClaudeCodeOnly,
|
||||||
FieldFallbackGroupID,
|
FieldFallbackGroupID,
|
||||||
|
FieldModelRouting,
|
||||||
|
FieldModelRoutingEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -204,6 +210,8 @@ var (
|
|||||||
DefaultDefaultValidityDays int
|
DefaultDefaultValidityDays int
|
||||||
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
||||||
DefaultClaudeCodeOnly bool
|
DefaultClaudeCodeOnly bool
|
||||||
|
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
|
||||||
|
DefaultModelRoutingEnabled bool
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the Group queries.
|
// OrderOption defines the ordering options for the Group queries.
|
||||||
@@ -309,6 +317,11 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
|
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByModelRoutingEnabled orders the results by the model_routing_enabled field.
|
||||||
|
func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByAPIKeysCount orders the results by api_keys count.
|
// ByAPIKeysCount orders the results by api_keys count.
|
||||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -150,6 +150,11 @@ func FallbackGroupID(v int64) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ.
|
||||||
|
func ModelRoutingEnabled(v bool) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@@ -1065,6 +1070,26 @@ func FallbackGroupIDNotNil() predicate.Group {
|
|||||||
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
|
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field.
|
||||||
|
func ModelRoutingIsNil() predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldIsNull(FieldModelRouting))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelRoutingNotNil applies the NotNil predicate on the "model_routing" field.
|
||||||
|
func ModelRoutingNotNil() predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNotNull(FieldModelRouting))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelRoutingEnabledEQ applies the EQ predicate on the "model_routing_enabled" field.
|
||||||
|
func ModelRoutingEnabledEQ(v bool) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelRoutingEnabledNEQ applies the NEQ predicate on the "model_routing_enabled" field.
|
||||||
|
func ModelRoutingEnabledNEQ(v bool) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||||
func HasAPIKeys() predicate.Group {
|
func HasAPIKeys() predicate.Group {
|
||||||
return predicate.Group(func(s *sql.Selector) {
|
return predicate.Group(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -286,6 +286,26 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetModelRouting sets the "model_routing" field.
|
||||||
|
func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate {
|
||||||
|
_c.mutation.SetModelRouting(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||||
|
func (_c *GroupCreate) SetModelRoutingEnabled(v bool) *GroupCreate {
|
||||||
|
_c.mutation.SetModelRoutingEnabled(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
|
||||||
|
func (_c *GroupCreate) SetNillableModelRoutingEnabled(v *bool) *GroupCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetModelRoutingEnabled(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||||
_c.mutation.AddAPIKeyIDs(ids...)
|
_c.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -455,6 +475,10 @@ func (_c *GroupCreate) defaults() error {
|
|||||||
v := group.DefaultClaudeCodeOnly
|
v := group.DefaultClaudeCodeOnly
|
||||||
_c.mutation.SetClaudeCodeOnly(v)
|
_c.mutation.SetClaudeCodeOnly(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
|
||||||
|
v := group.DefaultModelRoutingEnabled
|
||||||
|
_c.mutation.SetModelRoutingEnabled(v)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -510,6 +534,9 @@ func (_c *GroupCreate) check() error {
|
|||||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||||
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
|
||||||
|
return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -613,6 +640,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||||
_node.FallbackGroupID = &value
|
_node.FallbackGroupID = &value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.ModelRouting(); ok {
|
||||||
|
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
||||||
|
_node.ModelRouting = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.ModelRoutingEnabled(); ok {
|
||||||
|
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
|
||||||
|
_node.ModelRoutingEnabled = value
|
||||||
|
}
|
||||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1093,6 +1128,36 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetModelRouting sets the "model_routing" field.
|
||||||
|
func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert {
|
||||||
|
u.Set(group.FieldModelRouting, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsert) UpdateModelRouting() *GroupUpsert {
|
||||||
|
u.SetExcluded(group.FieldModelRouting)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearModelRouting clears the value of the "model_routing" field.
|
||||||
|
func (u *GroupUpsert) ClearModelRouting() *GroupUpsert {
|
||||||
|
u.SetNull(group.FieldModelRouting)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||||
|
func (u *GroupUpsert) SetModelRoutingEnabled(v bool) *GroupUpsert {
|
||||||
|
u.Set(group.FieldModelRoutingEnabled, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsert) UpdateModelRoutingEnabled() *GroupUpsert {
|
||||||
|
u.SetExcluded(group.FieldModelRoutingEnabled)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -1516,6 +1581,41 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetModelRouting sets the "model_routing" field.
|
||||||
|
func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetModelRouting(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertOne) UpdateModelRouting() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateModelRouting()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearModelRouting clears the value of the "model_routing" field.
|
||||||
|
func (u *GroupUpsertOne) ClearModelRouting() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.ClearModelRouting()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||||
|
func (u *GroupUpsertOne) SetModelRoutingEnabled(v bool) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetModelRoutingEnabled(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertOne) UpdateModelRoutingEnabled() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateModelRoutingEnabled()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -2105,6 +2205,41 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetModelRouting sets the "model_routing" field.
|
||||||
|
func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetModelRouting(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertBulk) UpdateModelRouting() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateModelRouting()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearModelRouting clears the value of the "model_routing" field.
|
||||||
|
func (u *GroupUpsertBulk) ClearModelRouting() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.ClearModelRouting()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||||
|
func (u *GroupUpsertBulk) SetModelRoutingEnabled(v bool) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetModelRoutingEnabled(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertBulk) UpdateModelRoutingEnabled() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateModelRoutingEnabled()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -395,6 +395,32 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetModelRouting sets the "model_routing" field.
|
||||||
|
func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate {
|
||||||
|
_u.mutation.SetModelRouting(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearModelRouting clears the value of the "model_routing" field.
|
||||||
|
func (_u *GroupUpdate) ClearModelRouting() *GroupUpdate {
|
||||||
|
_u.mutation.ClearModelRouting()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||||
|
func (_u *GroupUpdate) SetModelRoutingEnabled(v bool) *GroupUpdate {
|
||||||
|
_u.mutation.SetModelRoutingEnabled(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdate) SetNillableModelRoutingEnabled(v *bool) *GroupUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetModelRoutingEnabled(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -803,6 +829,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.FallbackGroupIDCleared() {
|
if _u.mutation.FallbackGroupIDCleared() {
|
||||||
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.ModelRouting(); ok {
|
||||||
|
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.ModelRoutingCleared() {
|
||||||
|
_spec.ClearField(group.FieldModelRouting, field.TypeJSON)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
|
||||||
|
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1478,6 +1513,32 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetModelRouting sets the "model_routing" field.
|
||||||
|
func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne {
|
||||||
|
_u.mutation.SetModelRouting(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearModelRouting clears the value of the "model_routing" field.
|
||||||
|
func (_u *GroupUpdateOne) ClearModelRouting() *GroupUpdateOne {
|
||||||
|
_u.mutation.ClearModelRouting()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||||
|
func (_u *GroupUpdateOne) SetModelRoutingEnabled(v bool) *GroupUpdateOne {
|
||||||
|
_u.mutation.SetModelRoutingEnabled(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdateOne) SetNillableModelRoutingEnabled(v *bool) *GroupUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetModelRoutingEnabled(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -1916,6 +1977,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
|||||||
if _u.mutation.FallbackGroupIDCleared() {
|
if _u.mutation.FallbackGroupIDCleared() {
|
||||||
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.ModelRouting(); ok {
|
||||||
|
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.ModelRoutingCleared() {
|
||||||
|
_spec.ClearField(group.FieldModelRouting, field.TypeJSON)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
|
||||||
|
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ var (
|
|||||||
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||||
{Name: "concurrency", Type: field.TypeInt, Default: 3},
|
{Name: "concurrency", Type: field.TypeInt, Default: 3},
|
||||||
{Name: "priority", Type: field.TypeInt, Default: 50},
|
{Name: "priority", Type: field.TypeInt, Default: 50},
|
||||||
|
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||||
{Name: "error_message", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
{Name: "error_message", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||||
{Name: "last_used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
{Name: "last_used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
@@ -101,7 +102,7 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "accounts_proxies_proxy",
|
Symbol: "accounts_proxies_proxy",
|
||||||
Columns: []*schema.Column{AccountsColumns[24]},
|
Columns: []*schema.Column{AccountsColumns[25]},
|
||||||
RefColumns: []*schema.Column{ProxiesColumns[0]},
|
RefColumns: []*schema.Column{ProxiesColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
@@ -120,12 +121,12 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "account_status",
|
Name: "account_status",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[12]},
|
Columns: []*schema.Column{AccountsColumns[13]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_proxy_id",
|
Name: "account_proxy_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[24]},
|
Columns: []*schema.Column{AccountsColumns[25]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_priority",
|
Name: "account_priority",
|
||||||
@@ -135,27 +136,27 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "account_last_used_at",
|
Name: "account_last_used_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[14]},
|
Columns: []*schema.Column{AccountsColumns[15]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_schedulable",
|
Name: "account_schedulable",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[17]},
|
Columns: []*schema.Column{AccountsColumns[18]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_rate_limited_at",
|
Name: "account_rate_limited_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[18]},
|
Columns: []*schema.Column{AccountsColumns[19]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_rate_limit_reset_at",
|
Name: "account_rate_limit_reset_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[19]},
|
Columns: []*schema.Column{AccountsColumns[20]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_overload_until",
|
Name: "account_overload_until",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[20]},
|
Columns: []*schema.Column{AccountsColumns[21]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_deleted_at",
|
Name: "account_deleted_at",
|
||||||
@@ -225,6 +226,8 @@ var (
|
|||||||
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
||||||
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
||||||
|
{Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||||
|
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
|
||||||
}
|
}
|
||||||
// GroupsTable holds the schema information for the "groups" table.
|
// GroupsTable holds the schema information for the "groups" table.
|
||||||
GroupsTable = &schema.Table{
|
GroupsTable = &schema.Table{
|
||||||
@@ -449,6 +452,7 @@ var (
|
|||||||
{Name: "total_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
{Name: "total_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||||
{Name: "actual_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
{Name: "actual_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||||
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||||
|
{Name: "account_rate_multiplier", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||||
{Name: "billing_type", Type: field.TypeInt8, Default: 0},
|
{Name: "billing_type", Type: field.TypeInt8, Default: 0},
|
||||||
{Name: "stream", Type: field.TypeBool, Default: false},
|
{Name: "stream", Type: field.TypeBool, Default: false},
|
||||||
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
|
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
|
||||||
@@ -472,31 +476,31 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_api_keys_usage_logs",
|
Symbol: "usage_logs_api_keys_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_accounts_usage_logs",
|
Symbol: "usage_logs_accounts_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_groups_usage_logs",
|
Symbol: "usage_logs_groups_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_users_usage_logs",
|
Symbol: "usage_logs_users_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
@@ -505,32 +509,32 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id",
|
Name: "usagelog_user_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id",
|
Name: "usagelog_api_key_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_account_id",
|
Name: "usagelog_account_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_group_id",
|
Name: "usagelog_group_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_subscription_id",
|
Name: "usagelog_subscription_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_created_at",
|
Name: "usagelog_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[24]},
|
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_model",
|
Name: "usagelog_model",
|
||||||
@@ -545,12 +549,12 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id_created_at",
|
Name: "usagelog_user_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]},
|
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id_created_at",
|
Name: "usagelog_api_key_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]},
|
Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1187,6 +1187,8 @@ type AccountMutation struct {
|
|||||||
addconcurrency *int
|
addconcurrency *int
|
||||||
priority *int
|
priority *int
|
||||||
addpriority *int
|
addpriority *int
|
||||||
|
rate_multiplier *float64
|
||||||
|
addrate_multiplier *float64
|
||||||
status *string
|
status *string
|
||||||
error_message *string
|
error_message *string
|
||||||
last_used_at *time.Time
|
last_used_at *time.Time
|
||||||
@@ -1822,6 +1824,62 @@ func (m *AccountMutation) ResetPriority() {
|
|||||||
m.addpriority = nil
|
m.addpriority = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (m *AccountMutation) SetRateMultiplier(f float64) {
|
||||||
|
m.rate_multiplier = &f
|
||||||
|
m.addrate_multiplier = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateMultiplier returns the value of the "rate_multiplier" field in the mutation.
|
||||||
|
func (m *AccountMutation) RateMultiplier() (r float64, exists bool) {
|
||||||
|
v := m.rate_multiplier
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldRateMultiplier returns the old "rate_multiplier" field's value of the Account entity.
|
||||||
|
// If the Account object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *AccountMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldRateMultiplier requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.RateMultiplier, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRateMultiplier adds f to the "rate_multiplier" field.
|
||||||
|
func (m *AccountMutation) AddRateMultiplier(f float64) {
|
||||||
|
if m.addrate_multiplier != nil {
|
||||||
|
*m.addrate_multiplier += f
|
||||||
|
} else {
|
||||||
|
m.addrate_multiplier = &f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation.
|
||||||
|
func (m *AccountMutation) AddedRateMultiplier() (r float64, exists bool) {
|
||||||
|
v := m.addrate_multiplier
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetRateMultiplier resets all changes to the "rate_multiplier" field.
|
||||||
|
func (m *AccountMutation) ResetRateMultiplier() {
|
||||||
|
m.rate_multiplier = nil
|
||||||
|
m.addrate_multiplier = nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (m *AccountMutation) SetStatus(s string) {
|
func (m *AccountMutation) SetStatus(s string) {
|
||||||
m.status = &s
|
m.status = &s
|
||||||
@@ -2540,7 +2598,7 @@ func (m *AccountMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *AccountMutation) Fields() []string {
|
func (m *AccountMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 24)
|
fields := make([]string, 0, 25)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, account.FieldCreatedAt)
|
fields = append(fields, account.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -2577,6 +2635,9 @@ func (m *AccountMutation) Fields() []string {
|
|||||||
if m.priority != nil {
|
if m.priority != nil {
|
||||||
fields = append(fields, account.FieldPriority)
|
fields = append(fields, account.FieldPriority)
|
||||||
}
|
}
|
||||||
|
if m.rate_multiplier != nil {
|
||||||
|
fields = append(fields, account.FieldRateMultiplier)
|
||||||
|
}
|
||||||
if m.status != nil {
|
if m.status != nil {
|
||||||
fields = append(fields, account.FieldStatus)
|
fields = append(fields, account.FieldStatus)
|
||||||
}
|
}
|
||||||
@@ -2645,6 +2706,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.Concurrency()
|
return m.Concurrency()
|
||||||
case account.FieldPriority:
|
case account.FieldPriority:
|
||||||
return m.Priority()
|
return m.Priority()
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
return m.RateMultiplier()
|
||||||
case account.FieldStatus:
|
case account.FieldStatus:
|
||||||
return m.Status()
|
return m.Status()
|
||||||
case account.FieldErrorMessage:
|
case account.FieldErrorMessage:
|
||||||
@@ -2702,6 +2765,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value,
|
|||||||
return m.OldConcurrency(ctx)
|
return m.OldConcurrency(ctx)
|
||||||
case account.FieldPriority:
|
case account.FieldPriority:
|
||||||
return m.OldPriority(ctx)
|
return m.OldPriority(ctx)
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
return m.OldRateMultiplier(ctx)
|
||||||
case account.FieldStatus:
|
case account.FieldStatus:
|
||||||
return m.OldStatus(ctx)
|
return m.OldStatus(ctx)
|
||||||
case account.FieldErrorMessage:
|
case account.FieldErrorMessage:
|
||||||
@@ -2819,6 +2884,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetPriority(v)
|
m.SetPriority(v)
|
||||||
return nil
|
return nil
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetRateMultiplier(v)
|
||||||
|
return nil
|
||||||
case account.FieldStatus:
|
case account.FieldStatus:
|
||||||
v, ok := value.(string)
|
v, ok := value.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -2917,6 +2989,9 @@ func (m *AccountMutation) AddedFields() []string {
|
|||||||
if m.addpriority != nil {
|
if m.addpriority != nil {
|
||||||
fields = append(fields, account.FieldPriority)
|
fields = append(fields, account.FieldPriority)
|
||||||
}
|
}
|
||||||
|
if m.addrate_multiplier != nil {
|
||||||
|
fields = append(fields, account.FieldRateMultiplier)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2929,6 +3004,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) {
|
|||||||
return m.AddedConcurrency()
|
return m.AddedConcurrency()
|
||||||
case account.FieldPriority:
|
case account.FieldPriority:
|
||||||
return m.AddedPriority()
|
return m.AddedPriority()
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
return m.AddedRateMultiplier()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -2952,6 +3029,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.AddPriority(v)
|
m.AddPriority(v)
|
||||||
return nil
|
return nil
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.AddRateMultiplier(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Account numeric field %s", name)
|
return fmt.Errorf("unknown Account numeric field %s", name)
|
||||||
}
|
}
|
||||||
@@ -3090,6 +3174,9 @@ func (m *AccountMutation) ResetField(name string) error {
|
|||||||
case account.FieldPriority:
|
case account.FieldPriority:
|
||||||
m.ResetPriority()
|
m.ResetPriority()
|
||||||
return nil
|
return nil
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
m.ResetRateMultiplier()
|
||||||
|
return nil
|
||||||
case account.FieldStatus:
|
case account.FieldStatus:
|
||||||
m.ResetStatus()
|
m.ResetStatus()
|
||||||
return nil
|
return nil
|
||||||
@@ -3777,6 +3864,8 @@ type GroupMutation struct {
|
|||||||
claude_code_only *bool
|
claude_code_only *bool
|
||||||
fallback_group_id *int64
|
fallback_group_id *int64
|
||||||
addfallback_group_id *int64
|
addfallback_group_id *int64
|
||||||
|
model_routing *map[string][]int64
|
||||||
|
model_routing_enabled *bool
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
api_keys map[int64]struct{}
|
api_keys map[int64]struct{}
|
||||||
removedapi_keys map[int64]struct{}
|
removedapi_keys map[int64]struct{}
|
||||||
@@ -4887,6 +4976,91 @@ func (m *GroupMutation) ResetFallbackGroupID() {
|
|||||||
delete(m.clearedFields, group.FieldFallbackGroupID)
|
delete(m.clearedFields, group.FieldFallbackGroupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetModelRouting sets the "model_routing" field.
|
||||||
|
func (m *GroupMutation) SetModelRouting(value map[string][]int64) {
|
||||||
|
m.model_routing = &value
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelRouting returns the value of the "model_routing" field in the mutation.
|
||||||
|
func (m *GroupMutation) ModelRouting() (r map[string][]int64, exists bool) {
|
||||||
|
v := m.model_routing
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldModelRouting returns the old "model_routing" field's value of the Group entity.
|
||||||
|
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *GroupMutation) OldModelRouting(ctx context.Context) (v map[string][]int64, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldModelRouting is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldModelRouting requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldModelRouting: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.ModelRouting, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearModelRouting clears the value of the "model_routing" field.
|
||||||
|
func (m *GroupMutation) ClearModelRouting() {
|
||||||
|
m.model_routing = nil
|
||||||
|
m.clearedFields[group.FieldModelRouting] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelRoutingCleared returns if the "model_routing" field was cleared in this mutation.
|
||||||
|
func (m *GroupMutation) ModelRoutingCleared() bool {
|
||||||
|
_, ok := m.clearedFields[group.FieldModelRouting]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetModelRouting resets all changes to the "model_routing" field.
|
||||||
|
func (m *GroupMutation) ResetModelRouting() {
|
||||||
|
m.model_routing = nil
|
||||||
|
delete(m.clearedFields, group.FieldModelRouting)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||||
|
func (m *GroupMutation) SetModelRoutingEnabled(b bool) {
|
||||||
|
m.model_routing_enabled = &b
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelRoutingEnabled returns the value of the "model_routing_enabled" field in the mutation.
|
||||||
|
func (m *GroupMutation) ModelRoutingEnabled() (r bool, exists bool) {
|
||||||
|
v := m.model_routing_enabled
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldModelRoutingEnabled returns the old "model_routing_enabled" field's value of the Group entity.
|
||||||
|
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *GroupMutation) OldModelRoutingEnabled(ctx context.Context) (v bool, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldModelRoutingEnabled is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldModelRoutingEnabled requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldModelRoutingEnabled: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.ModelRoutingEnabled, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetModelRoutingEnabled resets all changes to the "model_routing_enabled" field.
|
||||||
|
func (m *GroupMutation) ResetModelRoutingEnabled() {
|
||||||
|
m.model_routing_enabled = nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||||
if m.api_keys == nil {
|
if m.api_keys == nil {
|
||||||
@@ -5245,7 +5419,7 @@ func (m *GroupMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *GroupMutation) Fields() []string {
|
func (m *GroupMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 19)
|
fields := make([]string, 0, 21)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, group.FieldCreatedAt)
|
fields = append(fields, group.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -5303,6 +5477,12 @@ func (m *GroupMutation) Fields() []string {
|
|||||||
if m.fallback_group_id != nil {
|
if m.fallback_group_id != nil {
|
||||||
fields = append(fields, group.FieldFallbackGroupID)
|
fields = append(fields, group.FieldFallbackGroupID)
|
||||||
}
|
}
|
||||||
|
if m.model_routing != nil {
|
||||||
|
fields = append(fields, group.FieldModelRouting)
|
||||||
|
}
|
||||||
|
if m.model_routing_enabled != nil {
|
||||||
|
fields = append(fields, group.FieldModelRoutingEnabled)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -5349,6 +5529,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.ClaudeCodeOnly()
|
return m.ClaudeCodeOnly()
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
return m.FallbackGroupID()
|
return m.FallbackGroupID()
|
||||||
|
case group.FieldModelRouting:
|
||||||
|
return m.ModelRouting()
|
||||||
|
case group.FieldModelRoutingEnabled:
|
||||||
|
return m.ModelRoutingEnabled()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -5396,6 +5580,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
|||||||
return m.OldClaudeCodeOnly(ctx)
|
return m.OldClaudeCodeOnly(ctx)
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
return m.OldFallbackGroupID(ctx)
|
return m.OldFallbackGroupID(ctx)
|
||||||
|
case group.FieldModelRouting:
|
||||||
|
return m.OldModelRouting(ctx)
|
||||||
|
case group.FieldModelRoutingEnabled:
|
||||||
|
return m.OldModelRoutingEnabled(ctx)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -5538,6 +5726,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetFallbackGroupID(v)
|
m.SetFallbackGroupID(v)
|
||||||
return nil
|
return nil
|
||||||
|
+	case group.FieldModelRouting:
+		v, ok := value.(map[string][]int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetModelRouting(v)
+		return nil
+	case group.FieldModelRoutingEnabled:
+		v, ok := value.(bool)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetModelRoutingEnabled(v)
+		return nil
 	}
 	return fmt.Errorf("unknown Group field %s", name)
 }
@@ -5706,6 +5908,9 @@ func (m *GroupMutation) ClearedFields() []string {
 	if m.FieldCleared(group.FieldFallbackGroupID) {
 		fields = append(fields, group.FieldFallbackGroupID)
 	}
+	if m.FieldCleared(group.FieldModelRouting) {
+		fields = append(fields, group.FieldModelRouting)
+	}
 	return fields
 }
@@ -5747,6 +5952,9 @@ func (m *GroupMutation) ClearField(name string) error {
 	case group.FieldFallbackGroupID:
 		m.ClearFallbackGroupID()
 		return nil
+	case group.FieldModelRouting:
+		m.ClearModelRouting()
+		return nil
 	}
 	return fmt.Errorf("unknown Group nullable field %s", name)
 }
@@ -5812,6 +6020,12 @@ func (m *GroupMutation) ResetField(name string) error {
 	case group.FieldFallbackGroupID:
 		m.ResetFallbackGroupID()
 		return nil
+	case group.FieldModelRouting:
+		m.ResetModelRouting()
+		return nil
+	case group.FieldModelRoutingEnabled:
+		m.ResetModelRoutingEnabled()
+		return nil
 	}
 	return fmt.Errorf("unknown Group field %s", name)
 }
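The generated `SetField`/`ClearField` arms above all follow one pattern: the incoming untyped value is type-asserted, then stored behind a pointer so "not set in this mutation" stays distinguishable from the field's zero value. A minimal self-contained sketch of that pattern (plain Go with a hypothetical `mutation` type, not the actual ent-generated API):

```go
package main

import "fmt"

// mutation mimics the generated GroupMutation shape: each settable field is
// held as a pointer, so nil means "not touched by this mutation" rather than
// the field's zero value.
type mutation struct {
	modelRoutingEnabled *bool
}

// SetField dispatches on the field name and type-asserts the untyped value,
// mirroring the generated switch above.
func (m *mutation) SetField(name string, value any) error {
	switch name {
	case "model_routing_enabled":
		v, ok := value.(bool)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.modelRoutingEnabled = &v
		return nil
	}
	return fmt.Errorf("unknown Group field %s", name)
}

func main() {
	m := &mutation{}
	fmt.Println(m.SetField("model_routing_enabled", true)) // <nil>
	fmt.Println(m.SetField("model_routing_enabled", 42))   // unexpected type int for field model_routing_enabled
}
```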
@@ -10190,6 +10404,8 @@ type UsageLogMutation struct {
 	addactual_cost             *float64
 	rate_multiplier            *float64
 	addrate_multiplier         *float64
+	account_rate_multiplier    *float64
+	addaccount_rate_multiplier *float64
 	billing_type               *int8
 	addbilling_type            *int8
 	stream                     *bool
@@ -11323,6 +11539,76 @@ func (m *UsageLogMutation) ResetRateMultiplier() {
 	m.addrate_multiplier = nil
 }
 
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (m *UsageLogMutation) SetAccountRateMultiplier(f float64) {
+	m.account_rate_multiplier = &f
+	m.addaccount_rate_multiplier = nil
+}
+
+// AccountRateMultiplier returns the value of the "account_rate_multiplier" field in the mutation.
+func (m *UsageLogMutation) AccountRateMultiplier() (r float64, exists bool) {
+	v := m.account_rate_multiplier
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldAccountRateMultiplier returns the old "account_rate_multiplier" field's value of the UsageLog entity.
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageLogMutation) OldAccountRateMultiplier(ctx context.Context) (v *float64, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldAccountRateMultiplier is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldAccountRateMultiplier requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldAccountRateMultiplier: %w", err)
+	}
+	return oldValue.AccountRateMultiplier, nil
+}
+
+// AddAccountRateMultiplier adds f to the "account_rate_multiplier" field.
+func (m *UsageLogMutation) AddAccountRateMultiplier(f float64) {
+	if m.addaccount_rate_multiplier != nil {
+		*m.addaccount_rate_multiplier += f
+	} else {
+		m.addaccount_rate_multiplier = &f
+	}
+}
+
+// AddedAccountRateMultiplier returns the value that was added to the "account_rate_multiplier" field in this mutation.
+func (m *UsageLogMutation) AddedAccountRateMultiplier() (r float64, exists bool) {
+	v := m.addaccount_rate_multiplier
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
+func (m *UsageLogMutation) ClearAccountRateMultiplier() {
+	m.account_rate_multiplier = nil
+	m.addaccount_rate_multiplier = nil
+	m.clearedFields[usagelog.FieldAccountRateMultiplier] = struct{}{}
+}
+
+// AccountRateMultiplierCleared returns if the "account_rate_multiplier" field was cleared in this mutation.
+func (m *UsageLogMutation) AccountRateMultiplierCleared() bool {
+	_, ok := m.clearedFields[usagelog.FieldAccountRateMultiplier]
+	return ok
+}
+
+// ResetAccountRateMultiplier resets all changes to the "account_rate_multiplier" field.
+func (m *UsageLogMutation) ResetAccountRateMultiplier() {
+	m.account_rate_multiplier = nil
+	m.addaccount_rate_multiplier = nil
+	delete(m.clearedFields, usagelog.FieldAccountRateMultiplier)
+}
+
 // SetBillingType sets the "billing_type" field.
 func (m *UsageLogMutation) SetBillingType(i int8) {
 	m.billing_type = &i
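The `AddAccountRateMultiplier` / `AddedAccountRateMultiplier` pair above coalesces repeated increments into a single delta held behind one pointer. The behavior in isolation (simplified stand-in type, not the generated `UsageLogMutation`):

```go
package main

import "fmt"

// mutation models only the "add" half of the generated code: repeated Add
// calls accumulate into one pointer; nil means no delta was recorded.
type mutation struct {
	add *float64
}

// Add mirrors AddAccountRateMultiplier: increment in place if a delta
// already exists, otherwise start a new one.
func (m *mutation) Add(f float64) {
	if m.add != nil {
		*m.add += f
	} else {
		m.add = &f
	}
}

// Added mirrors AddedAccountRateMultiplier: report the accumulated delta.
func (m *mutation) Added() (float64, bool) {
	if m.add == nil {
		return 0, false
	}
	return *m.add, true
}

func main() {
	m := &mutation{}
	m.Add(0.25)
	m.Add(0.25)
	v, ok := m.Added()
	fmt.Println(v, ok) // 0.5 true
}
```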
@@ -11963,7 +12249,7 @@ func (m *UsageLogMutation) Type() string {
 // order to get all numeric fields that were incremented/decremented, call
 // AddedFields().
 func (m *UsageLogMutation) Fields() []string {
-	fields := make([]string, 0, 29)
+	fields := make([]string, 0, 30)
 	if m.user != nil {
 		fields = append(fields, usagelog.FieldUserID)
 	}
@@ -12024,6 +12310,9 @@ func (m *UsageLogMutation) Fields() []string {
 	if m.rate_multiplier != nil {
 		fields = append(fields, usagelog.FieldRateMultiplier)
 	}
+	if m.account_rate_multiplier != nil {
+		fields = append(fields, usagelog.FieldAccountRateMultiplier)
+	}
 	if m.billing_type != nil {
 		fields = append(fields, usagelog.FieldBillingType)
 	}
@@ -12099,6 +12388,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
 		return m.ActualCost()
 	case usagelog.FieldRateMultiplier:
 		return m.RateMultiplier()
+	case usagelog.FieldAccountRateMultiplier:
+		return m.AccountRateMultiplier()
 	case usagelog.FieldBillingType:
 		return m.BillingType()
 	case usagelog.FieldStream:
@@ -12166,6 +12457,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
 		return m.OldActualCost(ctx)
 	case usagelog.FieldRateMultiplier:
 		return m.OldRateMultiplier(ctx)
+	case usagelog.FieldAccountRateMultiplier:
+		return m.OldAccountRateMultiplier(ctx)
 	case usagelog.FieldBillingType:
 		return m.OldBillingType(ctx)
 	case usagelog.FieldStream:
@@ -12333,6 +12626,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
 		}
 		m.SetRateMultiplier(v)
 		return nil
+	case usagelog.FieldAccountRateMultiplier:
+		v, ok := value.(float64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetAccountRateMultiplier(v)
+		return nil
 	case usagelog.FieldBillingType:
 		v, ok := value.(int8)
 		if !ok {
@@ -12443,6 +12743,9 @@ func (m *UsageLogMutation) AddedFields() []string {
 	if m.addrate_multiplier != nil {
 		fields = append(fields, usagelog.FieldRateMultiplier)
 	}
+	if m.addaccount_rate_multiplier != nil {
+		fields = append(fields, usagelog.FieldAccountRateMultiplier)
+	}
 	if m.addbilling_type != nil {
 		fields = append(fields, usagelog.FieldBillingType)
 	}
@@ -12489,6 +12792,8 @@ func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) {
 		return m.AddedActualCost()
 	case usagelog.FieldRateMultiplier:
 		return m.AddedRateMultiplier()
+	case usagelog.FieldAccountRateMultiplier:
+		return m.AddedAccountRateMultiplier()
 	case usagelog.FieldBillingType:
 		return m.AddedBillingType()
 	case usagelog.FieldDurationMs:
@@ -12597,6 +12902,13 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
 		}
 		m.AddRateMultiplier(v)
 		return nil
+	case usagelog.FieldAccountRateMultiplier:
+		v, ok := value.(float64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddAccountRateMultiplier(v)
+		return nil
 	case usagelog.FieldBillingType:
 		v, ok := value.(int8)
 		if !ok {
@@ -12639,6 +12951,9 @@ func (m *UsageLogMutation) ClearedFields() []string {
 	if m.FieldCleared(usagelog.FieldSubscriptionID) {
 		fields = append(fields, usagelog.FieldSubscriptionID)
 	}
+	if m.FieldCleared(usagelog.FieldAccountRateMultiplier) {
+		fields = append(fields, usagelog.FieldAccountRateMultiplier)
+	}
 	if m.FieldCleared(usagelog.FieldDurationMs) {
 		fields = append(fields, usagelog.FieldDurationMs)
 	}
@@ -12674,6 +12989,9 @@ func (m *UsageLogMutation) ClearField(name string) error {
 	case usagelog.FieldSubscriptionID:
 		m.ClearSubscriptionID()
 		return nil
+	case usagelog.FieldAccountRateMultiplier:
+		m.ClearAccountRateMultiplier()
+		return nil
 	case usagelog.FieldDurationMs:
 		m.ClearDurationMs()
 		return nil
@@ -12757,6 +13075,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
 	case usagelog.FieldRateMultiplier:
 		m.ResetRateMultiplier()
 		return nil
+	case usagelog.FieldAccountRateMultiplier:
+		m.ResetAccountRateMultiplier()
+		return nil
 	case usagelog.FieldBillingType:
 		m.ResetBillingType()
 		return nil
@@ -177,22 +177,26 @@ func init() {
 	accountDescPriority := accountFields[8].Descriptor()
 	// account.DefaultPriority holds the default value on creation for the priority field.
 	account.DefaultPriority = accountDescPriority.Default.(int)
+	// accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
+	accountDescRateMultiplier := accountFields[9].Descriptor()
+	// account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
+	account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
 	// accountDescStatus is the schema descriptor for status field.
-	accountDescStatus := accountFields[9].Descriptor()
+	accountDescStatus := accountFields[10].Descriptor()
 	// account.DefaultStatus holds the default value on creation for the status field.
 	account.DefaultStatus = accountDescStatus.Default.(string)
 	// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
 	account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
 	// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
-	accountDescAutoPauseOnExpired := accountFields[13].Descriptor()
+	accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
 	// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
 	account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
 	// accountDescSchedulable is the schema descriptor for schedulable field.
-	accountDescSchedulable := accountFields[14].Descriptor()
+	accountDescSchedulable := accountFields[15].Descriptor()
 	// account.DefaultSchedulable holds the default value on creation for the schedulable field.
 	account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
 	// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
-	accountDescSessionWindowStatus := accountFields[20].Descriptor()
+	accountDescSessionWindowStatus := accountFields[21].Descriptor()
 	// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
 	account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
 	accountgroupFields := schema.AccountGroup{}.Fields()
@@ -276,6 +280,10 @@ func init() {
 	groupDescClaudeCodeOnly := groupFields[14].Descriptor()
 	// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
 	group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
+	// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
+	groupDescModelRoutingEnabled := groupFields[17].Descriptor()
+	// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
+	group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
 	promocodeFields := schema.PromoCode{}.Fields()
 	_ = promocodeFields
 	// promocodeDescCode is the schema descriptor for code field.
@@ -578,31 +586,31 @@ func init() {
 	// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
 	usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
 	// usagelogDescBillingType is the schema descriptor for billing_type field.
-	usagelogDescBillingType := usagelogFields[20].Descriptor()
+	usagelogDescBillingType := usagelogFields[21].Descriptor()
 	// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
 	usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
 	// usagelogDescStream is the schema descriptor for stream field.
-	usagelogDescStream := usagelogFields[21].Descriptor()
+	usagelogDescStream := usagelogFields[22].Descriptor()
 	// usagelog.DefaultStream holds the default value on creation for the stream field.
 	usagelog.DefaultStream = usagelogDescStream.Default.(bool)
 	// usagelogDescUserAgent is the schema descriptor for user_agent field.
-	usagelogDescUserAgent := usagelogFields[24].Descriptor()
+	usagelogDescUserAgent := usagelogFields[25].Descriptor()
 	// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
 	usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
 	// usagelogDescIPAddress is the schema descriptor for ip_address field.
-	usagelogDescIPAddress := usagelogFields[25].Descriptor()
+	usagelogDescIPAddress := usagelogFields[26].Descriptor()
 	// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
 	usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
 	// usagelogDescImageCount is the schema descriptor for image_count field.
-	usagelogDescImageCount := usagelogFields[26].Descriptor()
+	usagelogDescImageCount := usagelogFields[27].Descriptor()
 	// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
 	usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
 	// usagelogDescImageSize is the schema descriptor for image_size field.
-	usagelogDescImageSize := usagelogFields[27].Descriptor()
+	usagelogDescImageSize := usagelogFields[28].Descriptor()
 	// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
 	usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
 	// usagelogDescCreatedAt is the schema descriptor for created_at field.
-	usagelogDescCreatedAt := usagelogFields[28].Descriptor()
+	usagelogDescCreatedAt := usagelogFields[29].Descriptor()
 	// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
 	usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
 	userMixin := schema.User{}.Mixin()
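The index bumps in these runtime hunks (`accountFields[9]` → `[10]`, `usagelogFields[20]` → `[21]`, and so on) exist because inserting a field into a schema's `Fields()` slice shifts every later descriptor by one position. A tiny illustration of that slice-ordering property (the field list here is hypothetical):

```go
package main

import "fmt"

// indexOf returns the position of s in xs, or -1 if absent.
func indexOf(xs []string, s string) int {
	for i, x := range xs {
		if x == s {
			return i
		}
	}
	return -1
}

func main() {
	// Before: "status" sits right after "priority".
	before := []string{"...", "priority", "status"}
	// After inserting "rate_multiplier", every later field moves up one slot.
	after := []string{"...", "priority", "rate_multiplier", "status"}
	fmt.Println(indexOf(before, "status"), indexOf(after, "status")) // 2 3
}
```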
@@ -102,6 +102,12 @@ func (Account) Fields() []ent.Field {
 		field.Int("priority").
 			Default(50),
 
+		// rate_multiplier: 账号计费倍率(>=0,允许 0 表示该账号计费为 0)
+		// 仅影响账号维度计费口径,不影响用户/API Key 扣费(分组倍率)
+		field.Float("rate_multiplier").
+			SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
+			Default(1.0),
+
 		// status: 账户状态,如 "active", "error", "disabled"
 		field.String("status").
 			MaxLen(20).
@@ -95,6 +95,17 @@ func (Group) Fields() []ent.Field {
 			Optional().
 			Nillable().
 			Comment("非 Claude Code 请求降级使用的分组 ID"),
+
+		// 模型路由配置 (added by migration 040)
+		field.JSON("model_routing", map[string][]int64{}).
+			Optional().
+			SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
+			Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
+
+		// 模型路由开关 (added by migration 041)
+		field.Bool("model_routing_enabled").
+			Default(false).
+			Comment("是否启用模型路由配置"),
 	}
 }
 
@@ -85,6 +85,12 @@ func (UsageLog) Fields() []ent.Field {
 			Default(1).
 			SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
 
+		// account_rate_multiplier: 账号计费倍率快照(NULL 表示按 1.0 处理)
+		field.Float("account_rate_multiplier").
+			Optional().
+			Nillable().
+			SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
+
 		// 其他字段
 		field.Int8("billing_type").
 			Default(0),
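An `Optional().Nillable()` float field like `account_rate_multiplier` maps to a `*float64` in Go and a NULLable `decimal(10,4)` column in Postgres; per the schema comment, a NULL snapshot is billed as 1.0. A hedged sketch of that fallback (the `effectiveMultiplier` helper is illustrative, not part of the generated API):

```go
package main

import "fmt"

// effectiveMultiplier mirrors the "NULL means 1.0" convention from the
// account_rate_multiplier schema comment: a nil snapshot falls back to 1.0.
func effectiveMultiplier(snapshot *float64) float64 {
	if snapshot == nil {
		return 1.0
	}
	return *snapshot
}

func main() {
	var unset *float64 // no snapshot recorded: billed at the default rate
	half := 0.5
	fmt.Println(effectiveMultiplier(unset)) // 1
	fmt.Println(effectiveMultiplier(&half)) // 0.5
}
```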
@@ -62,6 +62,8 @@ type UsageLog struct {
 	ActualCost float64 `json:"actual_cost,omitempty"`
 	// RateMultiplier holds the value of the "rate_multiplier" field.
 	RateMultiplier float64 `json:"rate_multiplier,omitempty"`
+	// AccountRateMultiplier holds the value of the "account_rate_multiplier" field.
+	AccountRateMultiplier *float64 `json:"account_rate_multiplier,omitempty"`
 	// BillingType holds the value of the "billing_type" field.
 	BillingType int8 `json:"billing_type,omitempty"`
 	// Stream holds the value of the "stream" field.
@@ -165,7 +167,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
 		switch columns[i] {
 		case usagelog.FieldStream:
 			values[i] = new(sql.NullBool)
-		case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier:
+		case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
 			values[i] = new(sql.NullFloat64)
 		case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
 			values[i] = new(sql.NullInt64)
@@ -316,6 +318,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
 			} else if value.Valid {
 				_m.RateMultiplier = value.Float64
 			}
+		case usagelog.FieldAccountRateMultiplier:
+			if value, ok := values[i].(*sql.NullFloat64); !ok {
+				return fmt.Errorf("unexpected type %T for field account_rate_multiplier", values[i])
+			} else if value.Valid {
+				_m.AccountRateMultiplier = new(float64)
+				*_m.AccountRateMultiplier = value.Float64
+			}
 		case usagelog.FieldBillingType:
 			if value, ok := values[i].(*sql.NullInt64); !ok {
 				return fmt.Errorf("unexpected type %T for field billing_type", values[i])
@@ -500,6 +509,11 @@ func (_m *UsageLog) String() string {
 	builder.WriteString("rate_multiplier=")
 	builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
 	builder.WriteString(", ")
+	if v := _m.AccountRateMultiplier; v != nil {
+		builder.WriteString("account_rate_multiplier=")
+		builder.WriteString(fmt.Sprintf("%v", *v))
+	}
+	builder.WriteString(", ")
 	builder.WriteString("billing_type=")
 	builder.WriteString(fmt.Sprintf("%v", _m.BillingType))
 	builder.WriteString(", ")
@@ -54,6 +54,8 @@ const (
 	FieldActualCost = "actual_cost"
 	// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
 	FieldRateMultiplier = "rate_multiplier"
+	// FieldAccountRateMultiplier holds the string denoting the account_rate_multiplier field in the database.
+	FieldAccountRateMultiplier = "account_rate_multiplier"
 	// FieldBillingType holds the string denoting the billing_type field in the database.
 	FieldBillingType = "billing_type"
 	// FieldStream holds the string denoting the stream field in the database.
@@ -144,6 +146,7 @@ var Columns = []string{
 	FieldTotalCost,
 	FieldActualCost,
 	FieldRateMultiplier,
+	FieldAccountRateMultiplier,
 	FieldBillingType,
 	FieldStream,
 	FieldDurationMs,
@@ -320,6 +323,11 @@ func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
 	return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
 }
 
+// ByAccountRateMultiplier orders the results by the account_rate_multiplier field.
+func ByAccountRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
+	return sql.OrderByField(FieldAccountRateMultiplier, opts...).ToFunc()
+}
+
 // ByBillingType orders the results by the billing_type field.
 func ByBillingType(opts ...sql.OrderTermOption) OrderOption {
 	return sql.OrderByField(FieldBillingType, opts...).ToFunc()
@@ -155,6 +155,11 @@ func RateMultiplier(v float64) predicate.UsageLog {
 	return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v))
 }
 
+// AccountRateMultiplier applies equality check predicate on the "account_rate_multiplier" field. It's identical to AccountRateMultiplierEQ.
+func AccountRateMultiplier(v float64) predicate.UsageLog {
+	return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
+}
+
 // BillingType applies equality check predicate on the "billing_type" field. It's identical to BillingTypeEQ.
 func BillingType(v int8) predicate.UsageLog {
 	return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
@@ -970,6 +975,56 @@ func RateMultiplierLTE(v float64) predicate.UsageLog {
 	return predicate.UsageLog(sql.FieldLTE(FieldRateMultiplier, v))
 }
 
+// AccountRateMultiplierEQ applies the EQ predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierEQ(v float64) predicate.UsageLog {
+	return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
+}
+
+// AccountRateMultiplierNEQ applies the NEQ predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierNEQ(v float64) predicate.UsageLog {
+	return predicate.UsageLog(sql.FieldNEQ(FieldAccountRateMultiplier, v))
+}
+
+// AccountRateMultiplierIn applies the In predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierIn(vs ...float64) predicate.UsageLog {
+	return predicate.UsageLog(sql.FieldIn(FieldAccountRateMultiplier, vs...))
+}
+
+// AccountRateMultiplierNotIn applies the NotIn predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierNotIn(vs ...float64) predicate.UsageLog {
+	return predicate.UsageLog(sql.FieldNotIn(FieldAccountRateMultiplier, vs...))
+}
+
+// AccountRateMultiplierGT applies the GT predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierGT(v float64) predicate.UsageLog {
+	return predicate.UsageLog(sql.FieldGT(FieldAccountRateMultiplier, v))
+}
+
+// AccountRateMultiplierGTE applies the GTE predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierGTE(v float64) predicate.UsageLog {
+	return predicate.UsageLog(sql.FieldGTE(FieldAccountRateMultiplier, v))
+}
+
+// AccountRateMultiplierLT applies the LT predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierLT(v float64) predicate.UsageLog {
+	return predicate.UsageLog(sql.FieldLT(FieldAccountRateMultiplier, v))
+}
|
// AccountRateMultiplierLTE applies the LTE predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierLTE(v float64) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldLTE(FieldAccountRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierIsNil applies the IsNil predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierIsNil() predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldIsNull(FieldAccountRateMultiplier))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierNotNil applies the NotNil predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierNotNil() predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNotNull(FieldAccountRateMultiplier))
|
||||||
|
}
|
||||||
|
|
||||||
// BillingTypeEQ applies the EQ predicate on the "billing_type" field.
|
// BillingTypeEQ applies the EQ predicate on the "billing_type" field.
|
||||||
func BillingTypeEQ(v int8) predicate.UsageLog {
|
func BillingTypeEQ(v int8) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
|
||||||
|
@@ -267,6 +267,20 @@ func (_c *UsageLogCreate) SetNillableRateMultiplier(v *float64) *UsageLogCreate
 	return _c
 }
 
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (_c *UsageLogCreate) SetAccountRateMultiplier(v float64) *UsageLogCreate {
+	_c.mutation.SetAccountRateMultiplier(v)
+	return _c
+}
+
+// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
+func (_c *UsageLogCreate) SetNillableAccountRateMultiplier(v *float64) *UsageLogCreate {
+	if v != nil {
+		_c.SetAccountRateMultiplier(*v)
+	}
+	return _c
+}
+
 // SetBillingType sets the "billing_type" field.
 func (_c *UsageLogCreate) SetBillingType(v int8) *UsageLogCreate {
 	_c.mutation.SetBillingType(v)
@@ -712,6 +726,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
 		_spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
 		_node.RateMultiplier = value
 	}
+	if value, ok := _c.mutation.AccountRateMultiplier(); ok {
+		_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
+		_node.AccountRateMultiplier = &value
+	}
 	if value, ok := _c.mutation.BillingType(); ok {
 		_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
 		_node.BillingType = value
@@ -1215,6 +1233,30 @@ func (u *UsageLogUpsert) AddRateMultiplier(v float64) *UsageLogUpsert {
 	return u
 }
 
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (u *UsageLogUpsert) SetAccountRateMultiplier(v float64) *UsageLogUpsert {
+	u.Set(usagelog.FieldAccountRateMultiplier, v)
+	return u
+}
+
+// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
+func (u *UsageLogUpsert) UpdateAccountRateMultiplier() *UsageLogUpsert {
+	u.SetExcluded(usagelog.FieldAccountRateMultiplier)
+	return u
+}
+
+// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
+func (u *UsageLogUpsert) AddAccountRateMultiplier(v float64) *UsageLogUpsert {
+	u.Add(usagelog.FieldAccountRateMultiplier, v)
+	return u
+}
+
+// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
+func (u *UsageLogUpsert) ClearAccountRateMultiplier() *UsageLogUpsert {
+	u.SetNull(usagelog.FieldAccountRateMultiplier)
+	return u
+}
+
 // SetBillingType sets the "billing_type" field.
 func (u *UsageLogUpsert) SetBillingType(v int8) *UsageLogUpsert {
 	u.Set(usagelog.FieldBillingType, v)
@@ -1795,6 +1837,34 @@ func (u *UsageLogUpsertOne) UpdateRateMultiplier() *UsageLogUpsertOne {
 	})
 }
 
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (u *UsageLogUpsertOne) SetAccountRateMultiplier(v float64) *UsageLogUpsertOne {
+	return u.Update(func(s *UsageLogUpsert) {
+		s.SetAccountRateMultiplier(v)
+	})
+}
+
+// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
+func (u *UsageLogUpsertOne) AddAccountRateMultiplier(v float64) *UsageLogUpsertOne {
+	return u.Update(func(s *UsageLogUpsert) {
+		s.AddAccountRateMultiplier(v)
+	})
+}
+
+// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
+func (u *UsageLogUpsertOne) UpdateAccountRateMultiplier() *UsageLogUpsertOne {
+	return u.Update(func(s *UsageLogUpsert) {
+		s.UpdateAccountRateMultiplier()
+	})
+}
+
+// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
+func (u *UsageLogUpsertOne) ClearAccountRateMultiplier() *UsageLogUpsertOne {
+	return u.Update(func(s *UsageLogUpsert) {
+		s.ClearAccountRateMultiplier()
+	})
+}
+
 // SetBillingType sets the "billing_type" field.
 func (u *UsageLogUpsertOne) SetBillingType(v int8) *UsageLogUpsertOne {
 	return u.Update(func(s *UsageLogUpsert) {
@@ -2566,6 +2636,34 @@ func (u *UsageLogUpsertBulk) UpdateRateMultiplier() *UsageLogUpsertBulk {
 	})
 }
 
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (u *UsageLogUpsertBulk) SetAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
+	return u.Update(func(s *UsageLogUpsert) {
+		s.SetAccountRateMultiplier(v)
+	})
+}
+
+// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
+func (u *UsageLogUpsertBulk) AddAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
+	return u.Update(func(s *UsageLogUpsert) {
+		s.AddAccountRateMultiplier(v)
+	})
+}
+
+// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
+func (u *UsageLogUpsertBulk) UpdateAccountRateMultiplier() *UsageLogUpsertBulk {
+	return u.Update(func(s *UsageLogUpsert) {
+		s.UpdateAccountRateMultiplier()
+	})
+}
+
+// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
+func (u *UsageLogUpsertBulk) ClearAccountRateMultiplier() *UsageLogUpsertBulk {
+	return u.Update(func(s *UsageLogUpsert) {
+		s.ClearAccountRateMultiplier()
+	})
+}
+
 // SetBillingType sets the "billing_type" field.
 func (u *UsageLogUpsertBulk) SetBillingType(v int8) *UsageLogUpsertBulk {
 	return u.Update(func(s *UsageLogUpsert) {
@@ -415,6 +415,33 @@ func (_u *UsageLogUpdate) AddRateMultiplier(v float64) *UsageLogUpdate {
 	return _u
 }
 
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (_u *UsageLogUpdate) SetAccountRateMultiplier(v float64) *UsageLogUpdate {
+	_u.mutation.ResetAccountRateMultiplier()
+	_u.mutation.SetAccountRateMultiplier(v)
+	return _u
+}
+
+// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
+func (_u *UsageLogUpdate) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdate {
+	if v != nil {
+		_u.SetAccountRateMultiplier(*v)
+	}
+	return _u
+}
+
+// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
+func (_u *UsageLogUpdate) AddAccountRateMultiplier(v float64) *UsageLogUpdate {
+	_u.mutation.AddAccountRateMultiplier(v)
+	return _u
+}
+
+// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
+func (_u *UsageLogUpdate) ClearAccountRateMultiplier() *UsageLogUpdate {
+	_u.mutation.ClearAccountRateMultiplier()
+	return _u
+}
+
 // SetBillingType sets the "billing_type" field.
 func (_u *UsageLogUpdate) SetBillingType(v int8) *UsageLogUpdate {
 	_u.mutation.ResetBillingType()
@@ -807,6 +834,15 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
 	if value, ok := _u.mutation.AddedRateMultiplier(); ok {
 		_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
 	}
+	if value, ok := _u.mutation.AccountRateMultiplier(); ok {
+		_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
+	}
+	if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
+		_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
+	}
+	if _u.mutation.AccountRateMultiplierCleared() {
+		_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
+	}
 	if value, ok := _u.mutation.BillingType(); ok {
 		_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
 	}
@@ -1406,6 +1442,33 @@ func (_u *UsageLogUpdateOne) AddRateMultiplier(v float64) *UsageLogUpdateOne {
 	return _u
 }
 
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (_u *UsageLogUpdateOne) SetAccountRateMultiplier(v float64) *UsageLogUpdateOne {
+	_u.mutation.ResetAccountRateMultiplier()
+	_u.mutation.SetAccountRateMultiplier(v)
+	return _u
+}
+
+// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
+func (_u *UsageLogUpdateOne) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdateOne {
+	if v != nil {
+		_u.SetAccountRateMultiplier(*v)
+	}
+	return _u
+}
+
+// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
+func (_u *UsageLogUpdateOne) AddAccountRateMultiplier(v float64) *UsageLogUpdateOne {
+	_u.mutation.AddAccountRateMultiplier(v)
+	return _u
+}
+
+// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
+func (_u *UsageLogUpdateOne) ClearAccountRateMultiplier() *UsageLogUpdateOne {
+	_u.mutation.ClearAccountRateMultiplier()
+	return _u
+}
+
 // SetBillingType sets the "billing_type" field.
 func (_u *UsageLogUpdateOne) SetBillingType(v int8) *UsageLogUpdateOne {
 	_u.mutation.ResetBillingType()
@@ -1828,6 +1891,15 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
 	if value, ok := _u.mutation.AddedRateMultiplier(); ok {
 		_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
 	}
+	if value, ok := _u.mutation.AccountRateMultiplier(); ok {
+		_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
+	}
+	if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
+		_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
+	}
+	if _u.mutation.AccountRateMultiplierCleared() {
+		_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
+	}
 	if value, ok := _u.mutation.BillingType(); ok {
 		_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
 	}
@@ -19,7 +19,9 @@ const (
 	RunModeSimple = "simple"
 )
 
-const DefaultCSPPolicy = "default-src 'self'; script-src 'self' https://challenges.cloudflare.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
+// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
+// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
+const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
 
 // Connection pool isolation strategy constants
 // Controls the isolation granularity of upstream HTTP connection pools, affecting connection reuse and resource consumption
@@ -232,6 +234,10 @@ type GatewayConfig struct {
 	// ConcurrencySlotTTLMinutes: concurrency slot TTL (minutes)
 	// Should exceed the longest LLM request duration so slots don't expire before the request completes
 	ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
+	// SessionIdleTimeoutMinutes: session idle timeout (minutes), default 5
+	// Used by the session count limit feature for Anthropic OAuth/SetupToken accounts
+	// Sessions idle longer than this are released automatically
+	SessionIdleTimeoutMinutes int `mapstructure:"session_idle_timeout_minutes"`
 
 	// StreamDataIntervalTimeout: stream data interval timeout (seconds); 0 disables it
 	StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
@@ -44,6 +44,7 @@ type AccountHandler struct {
 	accountTestService *service.AccountTestService
 	concurrencyService *service.ConcurrencyService
 	crsSyncService     *service.CRSSyncService
+	sessionLimitCache  service.SessionLimitCache
 }
 
 // NewAccountHandler creates a new admin account handler
@@ -58,6 +59,7 @@ func NewAccountHandler(
 	accountTestService *service.AccountTestService,
 	concurrencyService *service.ConcurrencyService,
 	crsSyncService *service.CRSSyncService,
+	sessionLimitCache service.SessionLimitCache,
 ) *AccountHandler {
 	return &AccountHandler{
 		adminService: adminService,
@@ -70,6 +72,7 @@ func NewAccountHandler(
 		accountTestService: accountTestService,
 		concurrencyService: concurrencyService,
 		crsSyncService:     crsSyncService,
+		sessionLimitCache:  sessionLimitCache,
 	}
 }
 
@@ -84,6 +87,7 @@ type CreateAccountRequest struct {
 	ProxyID            *int64   `json:"proxy_id"`
 	Concurrency        int      `json:"concurrency"`
 	Priority           int      `json:"priority"`
+	RateMultiplier     *float64 `json:"rate_multiplier"`
 	GroupIDs           []int64  `json:"group_ids"`
 	ExpiresAt          *int64   `json:"expires_at"`
 	AutoPauseOnExpired *bool    `json:"auto_pause_on_expired"`
@@ -101,6 +105,7 @@ type UpdateAccountRequest struct {
 	ProxyID        *int64   `json:"proxy_id"`
 	Concurrency    *int     `json:"concurrency"`
 	Priority       *int     `json:"priority"`
+	RateMultiplier *float64 `json:"rate_multiplier"`
 	Status         string   `json:"status" binding:"omitempty,oneof=active inactive"`
 	GroupIDs       *[]int64 `json:"group_ids"`
 	ExpiresAt      *int64   `json:"expires_at"`
@@ -115,6 +120,7 @@ type BulkUpdateAccountsRequest struct {
 	ProxyID        *int64   `json:"proxy_id"`
 	Concurrency    *int     `json:"concurrency"`
 	Priority       *int     `json:"priority"`
+	RateMultiplier *float64 `json:"rate_multiplier"`
 	Status         string   `json:"status" binding:"omitempty,oneof=active inactive error"`
 	Schedulable    *bool    `json:"schedulable"`
 	GroupIDs       *[]int64 `json:"group_ids"`
@@ -127,6 +133,9 @@ type BulkUpdateAccountsRequest struct {
 type AccountWithConcurrency struct {
 	*dto.Account
 	CurrentConcurrency int `json:"current_concurrency"`
+	// The following fields apply only to Anthropic OAuth/SetupToken accounts, and are returned only when the corresponding feature is enabled
+	CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // current window cost
+	ActiveSessions    *int     `json:"active_sessions,omitempty"`     // current active session count
 }
 
 // List handles listing all accounts with pagination
@@ -161,13 +170,89 @@ func (h *AccountHandler) List(c *gin.Context) {
 		concurrencyCounts = make(map[int64]int)
 	}
 
+	// Identify accounts that need window cost and session count lookups (Anthropic OAuth/SetupToken with the feature enabled)
+	windowCostAccountIDs := make([]int64, 0)
+	sessionLimitAccountIDs := make([]int64, 0)
+	for i := range accounts {
+		acc := &accounts[i]
+		if acc.IsAnthropicOAuthOrSetupToken() {
+			if acc.GetWindowCostLimit() > 0 {
+				windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
+			}
+			if acc.GetMaxSessions() > 0 {
+				sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
+			}
+		}
+	}
+
+	// Fetch window costs and active session counts in parallel
+	var windowCosts map[int64]float64
+	var activeSessions map[int64]int
+
+	// Fetch active session counts (batch query)
+	if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
+		activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
+		if activeSessions == nil {
+			activeSessions = make(map[int64]int)
+		}
+	}
+
+	// Fetch window costs (parallel queries)
+	if len(windowCostAccountIDs) > 0 {
+		windowCosts = make(map[int64]float64)
+		var mu sync.Mutex
+		g, gctx := errgroup.WithContext(c.Request.Context())
+		g.SetLimit(10) // limit concurrency
+
+		for i := range accounts {
+			acc := &accounts[i]
+			if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
+				continue
+			}
+			accCopy := acc // capture for the closure
+			g.Go(func() error {
+				var startTime time.Time
+				if accCopy.SessionWindowStart != nil {
+					startTime = *accCopy.SessionWindowStart
+				} else {
+					startTime = time.Now().Add(-5 * time.Hour)
+				}
+				stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
+				if err == nil && stats != nil {
+					mu.Lock()
+					windowCosts[accCopy.ID] = stats.StandardCost // use the standard cost
+					mu.Unlock()
+				}
+				return nil // do not return errors; allow partial failure
+			})
+		}
+		_ = g.Wait()
+	}
+
 	// Build response with concurrency info
 	result := make([]AccountWithConcurrency, len(accounts))
 	for i := range accounts {
-		result[i] = AccountWithConcurrency{
-			Account:            dto.AccountFromService(&accounts[i]),
-			CurrentConcurrency: concurrencyCounts[accounts[i].ID],
+		acc := &accounts[i]
+		item := AccountWithConcurrency{
+			Account:            dto.AccountFromService(acc),
+			CurrentConcurrency: concurrencyCounts[acc.ID],
 		}
+
+		// Attach window cost (only when enabled)
+		if windowCosts != nil {
+			if cost, ok := windowCosts[acc.ID]; ok {
+				item.CurrentWindowCost = &cost
+			}
+		}
+
+		// Attach active session count (only when enabled)
+		if activeSessions != nil {
+			if count, ok := activeSessions[acc.ID]; ok {
+				item.ActiveSessions = &count
+			}
+		}
+
+		result[i] = item
 	}
 
 	response.Paginated(c, result, total, page, pageSize)
@@ -199,6 +284,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
 		response.BadRequest(c, "Invalid request: "+err.Error())
 		return
 	}
+	if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
+		response.BadRequest(c, "rate_multiplier must be >= 0")
+		return
+	}
 
 	// Determine whether to skip the mixed-channel check
 	skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -213,6 +302,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
 		ProxyID:            req.ProxyID,
 		Concurrency:        req.Concurrency,
 		Priority:           req.Priority,
+		RateMultiplier:     req.RateMultiplier,
 		GroupIDs:           req.GroupIDs,
 		ExpiresAt:          req.ExpiresAt,
 		AutoPauseOnExpired: req.AutoPauseOnExpired,
@@ -258,6 +348,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
 		response.BadRequest(c, "Invalid request: "+err.Error())
 		return
 	}
+	if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
+		response.BadRequest(c, "rate_multiplier must be >= 0")
+		return
+	}
 
 	// Determine whether to skip the mixed-channel check
 	skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -271,6 +365,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
 		ProxyID:        req.ProxyID,
 		Concurrency:    req.Concurrency, // pointer; nil means not provided
 		Priority:       req.Priority,    // pointer; nil means not provided
+		RateMultiplier: req.RateMultiplier,
 		Status:         req.Status,
 		GroupIDs:       req.GroupIDs,
 		ExpiresAt:      req.ExpiresAt,
@@ -652,6 +747,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
 		response.BadRequest(c, "Invalid request: "+err.Error())
 		return
 	}
+	if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
+		response.BadRequest(c, "rate_multiplier must be >= 0")
+		return
+	}
 
 	// Determine whether to skip the mixed-channel check
 	skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -660,6 +759,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
 		req.ProxyID != nil ||
 		req.Concurrency != nil ||
 		req.Priority != nil ||
+		req.RateMultiplier != nil ||
 		req.Status != "" ||
 		req.Schedulable != nil ||
 		req.GroupIDs != nil ||
@@ -677,6 +777,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
 		ProxyID:        req.ProxyID,
 		Concurrency:    req.Concurrency,
 		Priority:       req.Priority,
+		RateMultiplier: req.RateMultiplier,
 		Status:         req.Status,
 		Schedulable:    req.Schedulable,
 		GroupIDs:       req.GroupIDs,
|||||||
@@ -186,13 +186,16 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
|||||||
|
|
||||||
// GetUsageTrend handles getting usage trend data
|
// GetUsageTrend handles getting usage trend data
|
||||||
// GET /api/v1/admin/dashboard/trend
|
// GET /api/v1/admin/dashboard/trend
|
||||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
|
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream
|
||||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||||
startTime, endTime := parseTimeRange(c)
|
startTime, endTime := parseTimeRange(c)
|
||||||
granularity := c.DefaultQuery("granularity", "day")
|
granularity := c.DefaultQuery("granularity", "day")
|
||||||
|
|
||||||
// Parse optional filter params
|
// Parse optional filter params
|
||||||
var userID, apiKeyID int64
|
var userID, apiKeyID, accountID, groupID int64
|
||||||
|
var model string
|
||||||
|
var stream *bool
|
||||||
|
|
||||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||||
userID = id
|
userID = id
|
||||||
@@ -203,8 +206,26 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
             apiKeyID = id
         }
     }
+    if accountIDStr := c.Query("account_id"); accountIDStr != "" {
+        if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
+            accountID = id
+        }
+    }
+    if groupIDStr := c.Query("group_id"); groupIDStr != "" {
+        if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
+            groupID = id
+        }
+    }
+    if modelStr := c.Query("model"); modelStr != "" {
+        model = modelStr
+    }
+    if streamStr := c.Query("stream"); streamStr != "" {
+        if streamVal, err := strconv.ParseBool(streamStr); err == nil {
+            stream = &streamVal
+        }
+    }
 
-    trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
+    trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
     if err != nil {
         response.Error(c, 500, "Failed to get usage trend")
         return
@@ -220,12 +241,14 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
 
 // GetModelStats handles getting model usage statistics
 // GET /api/v1/admin/dashboard/models
-// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
+// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream
 func (h *DashboardHandler) GetModelStats(c *gin.Context) {
     startTime, endTime := parseTimeRange(c)
 
     // Parse optional filter params
-    var userID, apiKeyID int64
+    var userID, apiKeyID, accountID, groupID int64
+    var stream *bool
 
     if userIDStr := c.Query("user_id"); userIDStr != "" {
         if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
             userID = id
@@ -236,8 +259,23 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
             apiKeyID = id
         }
     }
+    if accountIDStr := c.Query("account_id"); accountIDStr != "" {
+        if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
+            accountID = id
+        }
+    }
+    if groupIDStr := c.Query("group_id"); groupIDStr != "" {
+        if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
+            groupID = id
+        }
+    }
+    if streamStr := c.Query("stream"); streamStr != "" {
+        if streamVal, err := strconv.ParseBool(streamStr); err == nil {
+            stream = &streamVal
+        }
+    }
 
-    stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
+    stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
     if err != nil {
         response.Error(c, 500, "Failed to get model statistics")
         return
@@ -40,6 +40,9 @@ type CreateGroupRequest struct {
     ImagePrice4K *float64 `json:"image_price_4k"`
     ClaudeCodeOnly bool `json:"claude_code_only"`
     FallbackGroupID *int64 `json:"fallback_group_id"`
+    // Model routing configuration (used by the anthropic platform only)
+    ModelRouting map[string][]int64 `json:"model_routing"`
+    ModelRoutingEnabled bool `json:"model_routing_enabled"`
 }
 
 // UpdateGroupRequest represents update group request
@@ -60,6 +63,9 @@ type UpdateGroupRequest struct {
     ImagePrice4K *float64 `json:"image_price_4k"`
     ClaudeCodeOnly *bool `json:"claude_code_only"`
     FallbackGroupID *int64 `json:"fallback_group_id"`
+    // Model routing configuration (used by the anthropic platform only)
+    ModelRouting map[string][]int64 `json:"model_routing"`
+    ModelRoutingEnabled *bool `json:"model_routing_enabled"`
 }
 
 // List handles listing all groups with pagination
@@ -163,6 +169,8 @@ func (h *GroupHandler) Create(c *gin.Context) {
         ImagePrice4K: req.ImagePrice4K,
         ClaudeCodeOnly: req.ClaudeCodeOnly,
         FallbackGroupID: req.FallbackGroupID,
+        ModelRouting: req.ModelRouting,
+        ModelRoutingEnabled: req.ModelRoutingEnabled,
     })
     if err != nil {
         response.ErrorFrom(c, err)
@@ -203,6 +211,8 @@ func (h *GroupHandler) Update(c *gin.Context) {
         ImagePrice4K: req.ImagePrice4K,
         ClaudeCodeOnly: req.ClaudeCodeOnly,
         FallbackGroupID: req.FallbackGroupID,
+        ModelRouting: req.ModelRouting,
+        ModelRoutingEnabled: req.ModelRoutingEnabled,
     })
     if err != nil {
         response.ErrorFrom(c, err)
@@ -7,8 +7,10 @@ import (
     "net/http"
     "strconv"
     "strings"
+    "time"
 
     "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+    "github.com/Wei-Shaw/sub2api/internal/server/middleware"
     "github.com/Wei-Shaw/sub2api/internal/service"
     "github.com/gin-gonic/gin"
     "github.com/gin-gonic/gin/binding"
@@ -18,8 +20,6 @@ var validOpsAlertMetricTypes = []string{
     "success_rate",
     "error_rate",
     "upstream_error_rate",
-    "p95_latency_ms",
-    "p99_latency_ms",
     "cpu_usage_percent",
     "memory_usage_percent",
     "concurrency_queue_depth",
@@ -372,8 +372,135 @@ func (h *OpsHandler) DeleteAlertRule(c *gin.Context) {
     response.Success(c, gin.H{"deleted": true})
 }
 
+// GetAlertEvent returns a single ops alert event.
+// GET /api/v1/admin/ops/alert-events/:id
+func (h *OpsHandler) GetAlertEvent(c *gin.Context) {
+    if h.opsService == nil {
+        response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+        return
+    }
+    if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+        response.ErrorFrom(c, err)
+        return
+    }
+
+    id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+    if err != nil || id <= 0 {
+        response.BadRequest(c, "Invalid event ID")
+        return
+    }
+
+    ev, err := h.opsService.GetAlertEventByID(c.Request.Context(), id)
+    if err != nil {
+        response.ErrorFrom(c, err)
+        return
+    }
+    response.Success(c, ev)
+}
+
+// UpdateAlertEventStatus updates an ops alert event status.
+// PUT /api/v1/admin/ops/alert-events/:id/status
+func (h *OpsHandler) UpdateAlertEventStatus(c *gin.Context) {
+    if h.opsService == nil {
+        response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+        return
+    }
+    if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+        response.ErrorFrom(c, err)
+        return
+    }
+
+    id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+    if err != nil || id <= 0 {
+        response.BadRequest(c, "Invalid event ID")
+        return
+    }
+
+    var payload struct {
+        Status string `json:"status"`
+    }
+    if err := c.ShouldBindJSON(&payload); err != nil {
+        response.BadRequest(c, "Invalid request body")
+        return
+    }
+    payload.Status = strings.TrimSpace(payload.Status)
+    if payload.Status == "" {
+        response.BadRequest(c, "Invalid status")
+        return
+    }
+    if payload.Status != service.OpsAlertStatusResolved && payload.Status != service.OpsAlertStatusManualResolved {
+        response.BadRequest(c, "Invalid status")
+        return
+    }
+
+    var resolvedAt *time.Time
+    if payload.Status == service.OpsAlertStatusResolved || payload.Status == service.OpsAlertStatusManualResolved {
+        now := time.Now().UTC()
+        resolvedAt = &now
+    }
+    if err := h.opsService.UpdateAlertEventStatus(c.Request.Context(), id, payload.Status, resolvedAt); err != nil {
+        response.ErrorFrom(c, err)
+        return
+    }
+    response.Success(c, gin.H{"updated": true})
+}
+
 // ListAlertEvents lists recent ops alert events.
 // GET /api/v1/admin/ops/alert-events
+// CreateAlertSilence creates a scoped silence for ops alerts.
+// POST /api/v1/admin/ops/alert-silences
+func (h *OpsHandler) CreateAlertSilence(c *gin.Context) {
+    if h.opsService == nil {
+        response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+        return
+    }
+    if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+        response.ErrorFrom(c, err)
+        return
+    }
+
+    var payload struct {
+        RuleID int64 `json:"rule_id"`
+        Platform string `json:"platform"`
+        GroupID *int64 `json:"group_id"`
+        Region *string `json:"region"`
+        Until string `json:"until"`
+        Reason string `json:"reason"`
+    }
+    if err := c.ShouldBindJSON(&payload); err != nil {
+        response.BadRequest(c, "Invalid request body")
+        return
+    }
+    until, err := time.Parse(time.RFC3339, strings.TrimSpace(payload.Until))
+    if err != nil {
+        response.BadRequest(c, "Invalid until")
+        return
+    }
+
+    createdBy := (*int64)(nil)
+    if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
+        uid := subject.UserID
+        createdBy = &uid
+    }
+
+    silence := &service.OpsAlertSilence{
+        RuleID: payload.RuleID,
+        Platform: strings.TrimSpace(payload.Platform),
+        GroupID: payload.GroupID,
+        Region: payload.Region,
+        Until: until,
+        Reason: strings.TrimSpace(payload.Reason),
+        CreatedBy: createdBy,
+    }
+
+    created, err := h.opsService.CreateAlertSilence(c.Request.Context(), silence)
+    if err != nil {
+        response.ErrorFrom(c, err)
+        return
+    }
+    response.Success(c, created)
+}
+
 func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
     if h.opsService == nil {
         response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
@@ -384,7 +511,7 @@ func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
         return
     }
 
-    limit := 100
+    limit := 20
     if raw := strings.TrimSpace(c.Query("limit")); raw != "" {
         n, err := strconv.Atoi(raw)
         if err != nil || n <= 0 {
@@ -400,6 +527,49 @@ func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
         Severity: strings.TrimSpace(c.Query("severity")),
     }
 
+    if v := strings.TrimSpace(c.Query("email_sent")); v != "" {
+        vv := strings.ToLower(v)
+        switch vv {
+        case "true", "1":
+            b := true
+            filter.EmailSent = &b
+        case "false", "0":
+            b := false
+            filter.EmailSent = &b
+        default:
+            response.BadRequest(c, "Invalid email_sent")
+            return
+        }
+    }
+
+    // Cursor pagination: both params must be provided together.
+    rawTS := strings.TrimSpace(c.Query("before_fired_at"))
+    rawID := strings.TrimSpace(c.Query("before_id"))
+    if (rawTS == "") != (rawID == "") {
+        response.BadRequest(c, "before_fired_at and before_id must be provided together")
+        return
+    }
+    if rawTS != "" {
+        ts, err := time.Parse(time.RFC3339Nano, rawTS)
+        if err != nil {
+            if t2, err2 := time.Parse(time.RFC3339, rawTS); err2 == nil {
+                ts = t2
+            } else {
+                response.BadRequest(c, "Invalid before_fired_at")
+                return
+            }
+        }
+        filter.BeforeFiredAt = &ts
+    }
+    if rawID != "" {
+        id, err := strconv.ParseInt(rawID, 10, 64)
+        if err != nil || id <= 0 {
+            response.BadRequest(c, "Invalid before_id")
+            return
+        }
+        filter.BeforeID = &id
+    }
+
     // Optional global filter support (platform/group/time range).
     if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
         filter.Platform = platform
@@ -19,6 +19,57 @@ type OpsHandler struct {
     opsService *service.OpsService
 }
 
+// GetErrorLogByID returns ops error log detail.
+// GET /api/v1/admin/ops/errors/:id
+func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
+    if h.opsService == nil {
+        response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+        return
+    }
+    if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+        response.ErrorFrom(c, err)
+        return
+    }
+
+    idStr := strings.TrimSpace(c.Param("id"))
+    id, err := strconv.ParseInt(idStr, 10, 64)
+    if err != nil || id <= 0 {
+        response.BadRequest(c, "Invalid error id")
+        return
+    }
+
+    detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
+    if err != nil {
+        response.ErrorFrom(c, err)
+        return
+    }
+
+    response.Success(c, detail)
+}
+
+const (
+    opsListViewErrors = "errors"
+    opsListViewExcluded = "excluded"
+    opsListViewAll = "all"
+)
+
+func parseOpsViewParam(c *gin.Context) string {
+    if c == nil {
+        return ""
+    }
+    v := strings.ToLower(strings.TrimSpace(c.Query("view")))
+    switch v {
+    case "", opsListViewErrors:
+        return opsListViewErrors
+    case opsListViewExcluded:
+        return opsListViewExcluded
+    case opsListViewAll:
+        return opsListViewAll
+    default:
+        return opsListViewErrors
+    }
+}
+
 func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
     return &OpsHandler{opsService: opsService}
 }
@@ -47,16 +98,26 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
         return
     }
 
-    filter := &service.OpsErrorLogFilter{
-        Page: page,
-        PageSize: pageSize,
-    }
+    filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
     if !startTime.IsZero() {
         filter.StartTime = &startTime
     }
     if !endTime.IsZero() {
         filter.EndTime = &endTime
     }
+    filter.View = parseOpsViewParam(c)
+    filter.Phase = strings.TrimSpace(c.Query("phase"))
+    filter.Owner = strings.TrimSpace(c.Query("error_owner"))
+    filter.Source = strings.TrimSpace(c.Query("error_source"))
+    filter.Query = strings.TrimSpace(c.Query("q"))
+    filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
+
+    // Force request errors: client-visible status >= 400.
+    // buildOpsErrorLogsWhere already applies this for non-upstream phase.
+    if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
+        filter.Phase = ""
+    }
+
     if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
         filter.Platform = platform
@@ -77,11 +138,19 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
         }
         filter.AccountID = &id
     }
-    if phase := strings.TrimSpace(c.Query("phase")); phase != "" {
-        filter.Phase = phase
+    if v := strings.TrimSpace(c.Query("resolved")); v != "" {
+        switch strings.ToLower(v) {
+        case "1", "true", "yes":
+            b := true
+            filter.Resolved = &b
+        case "0", "false", "no":
+            b := false
+            filter.Resolved = &b
+        default:
+            response.BadRequest(c, "Invalid resolved")
+            return
+        }
     }
-    if q := strings.TrimSpace(c.Query("q")); q != "" {
-        filter.Query = q
     }
     if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
         parts := strings.Split(statusCodesStr, ",")
@@ -106,13 +175,120 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
         response.ErrorFrom(c, err)
         return
     }
 
     response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
 }
 
-// GetErrorLogByID returns a single error log detail.
-// GET /api/v1/admin/ops/errors/:id
-func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
+// ListRequestErrors lists client-visible request errors.
+// GET /api/v1/admin/ops/request-errors
+func (h *OpsHandler) ListRequestErrors(c *gin.Context) {
+    if h.opsService == nil {
+        response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+        return
+    }
+    if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+        response.ErrorFrom(c, err)
+        return
+    }
+
+    page, pageSize := response.ParsePagination(c)
+    if pageSize > 500 {
+        pageSize = 500
+    }
+    startTime, endTime, err := parseOpsTimeRange(c, "1h")
+    if err != nil {
+        response.BadRequest(c, err.Error())
+        return
+    }
+
+    filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
+    if !startTime.IsZero() {
+        filter.StartTime = &startTime
+    }
+    if !endTime.IsZero() {
+        filter.EndTime = &endTime
+    }
+    filter.View = parseOpsViewParam(c)
+    filter.Phase = strings.TrimSpace(c.Query("phase"))
+    filter.Owner = strings.TrimSpace(c.Query("error_owner"))
+    filter.Source = strings.TrimSpace(c.Query("error_source"))
+    filter.Query = strings.TrimSpace(c.Query("q"))
+    filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
+
+    // Force request errors: client-visible status >= 400.
+    // buildOpsErrorLogsWhere already applies this for non-upstream phase.
+    if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
+        filter.Phase = ""
+    }
+
+    if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
+        filter.Platform = platform
+    }
+    if v := strings.TrimSpace(c.Query("group_id")); v != "" {
+        id, err := strconv.ParseInt(v, 10, 64)
+        if err != nil || id <= 0 {
+            response.BadRequest(c, "Invalid group_id")
+            return
+        }
+        filter.GroupID = &id
+    }
+    if v := strings.TrimSpace(c.Query("account_id")); v != "" {
+        id, err := strconv.ParseInt(v, 10, 64)
+        if err != nil || id <= 0 {
+            response.BadRequest(c, "Invalid account_id")
+            return
+        }
+        filter.AccountID = &id
+    }
+
+    if v := strings.TrimSpace(c.Query("resolved")); v != "" {
+        switch strings.ToLower(v) {
+        case "1", "true", "yes":
+            b := true
+            filter.Resolved = &b
+        case "0", "false", "no":
+            b := false
+            filter.Resolved = &b
+        default:
+            response.BadRequest(c, "Invalid resolved")
+            return
+        }
+    }
+    if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
+        parts := strings.Split(statusCodesStr, ",")
+        out := make([]int, 0, len(parts))
+        for _, part := range parts {
+            p := strings.TrimSpace(part)
+            if p == "" {
+                continue
+            }
+            n, err := strconv.Atoi(p)
+            if err != nil || n < 0 {
+                response.BadRequest(c, "Invalid status_codes")
+                return
+            }
+            out = append(out, n)
+        }
+        filter.StatusCodes = out
+    }
+
+    result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
+    if err != nil {
+        response.ErrorFrom(c, err)
+        return
+    }
+    response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
+}
+
+// GetRequestError returns request error detail.
+// GET /api/v1/admin/ops/request-errors/:id
+func (h *OpsHandler) GetRequestError(c *gin.Context) {
+    // same storage; just proxy to existing detail
+    h.GetErrorLogByID(c)
+}
+
+// ListRequestErrorUpstreamErrors lists upstream error logs correlated to a request error.
+// GET /api/v1/admin/ops/request-errors/:id/upstream-errors
+func (h *OpsHandler) ListRequestErrorUpstreamErrors(c *gin.Context) {
     if h.opsService == nil {
         response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
         return
@@ -129,15 +305,306 @@ func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load request error to get correlation keys.
|
||||||
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
|
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, detail)
|
// Correlate by request_id/client_request_id.
|
||||||
|
requestID := strings.TrimSpace(detail.RequestID)
|
||||||
|
clientRequestID := strings.TrimSpace(detail.ClientRequestID)
|
||||||
|
if requestID == "" && clientRequestID == "" {
|
||||||
|
response.Paginated(c, []*service.OpsErrorLog{}, 0, 1, 10)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
page, pageSize := response.ParsePagination(c)
|
||||||
|
if pageSize > 500 {
|
||||||
|
pageSize = 500
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep correlation window wide enough so linked upstream errors
|
||||||
|
// are discoverable even when UI defaults to 1h elsewhere.
|
||||||
|
startTime, endTime, err := parseOpsTimeRange(c, "30d")
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
|
||||||
|
if !startTime.IsZero() {
|
||||||
|
filter.StartTime = &startTime
|
||||||
|
}
|
||||||
|
if !endTime.IsZero() {
|
||||||
|
filter.EndTime = &endTime
|
||||||
|
}
|
||||||
|
filter.View = "all"
|
||||||
|
filter.Phase = "upstream"
|
||||||
|
filter.Owner = "provider"
|
||||||
|
filter.Source = strings.TrimSpace(c.Query("error_source"))
|
||||||
|
filter.Query = strings.TrimSpace(c.Query("q"))
|
||||||
|
|
||||||
|
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
||||||
|
filter.Platform = platform
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefer exact match on request_id; if missing, fall back to client_request_id.
|
||||||
|
if requestID != "" {
|
||||||
|
filter.RequestID = requestID
|
||||||
|
} else {
|
||||||
|
filter.ClientRequestID = clientRequestID
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If client asks for details, expand each upstream error log to include upstream response fields.
|
||||||
|
includeDetail := strings.TrimSpace(c.Query("include_detail"))
|
||||||
|
if includeDetail == "1" || strings.EqualFold(includeDetail, "true") || strings.EqualFold(includeDetail, "yes") {
|
||||||
|
details := make([]*service.OpsErrorLogDetail, 0, len(result.Errors))
|
||||||
|
for _, item := range result.Errors {
|
||||||
|
if item == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
d, err := h.opsService.GetErrorLogByID(c.Request.Context(), item.ID)
|
||||||
|
if err != nil || d == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
details = append(details, d)
|
||||||
|
}
|
||||||
|
response.Paginated(c, details, int64(result.Total), result.Page, result.PageSize)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryRequestErrorClient retries the client request based on stored request body.
|
||||||
|
// POST /api/v1/admin/ops/request-errors/:id/retry-client
|
||||||
|
func (h *OpsHandler) RetryRequestErrorClient(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok || subject.UserID <= 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idStr := strings.TrimSpace(c.Param("id"))
|
||||||
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid error id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeClient, nil)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryRequestErrorUpstreamEvent retries a specific upstream attempt using captured upstream_request_body.
|
||||||
|
// POST /api/v1/admin/ops/request-errors/:id/upstream-errors/:idx/retry
|
||||||
|
func (h *OpsHandler) RetryRequestErrorUpstreamEvent(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok || subject.UserID <= 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idStr := strings.TrimSpace(c.Param("id"))
|
||||||
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid error id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idxStr := strings.TrimSpace(c.Param("idx"))
|
||||||
|
idx, err := strconv.Atoi(idxStr)
|
||||||
|
if err != nil || idx < 0 {
|
||||||
|
response.BadRequest(c, "Invalid upstream idx")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.opsService.RetryUpstreamEvent(c.Request.Context(), subject.UserID, id, idx)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveRequestError toggles resolved status.
|
||||||
|
// PUT /api/v1/admin/ops/request-errors/:id/resolve
|
||||||
|
func (h *OpsHandler) ResolveRequestError(c *gin.Context) {
|
||||||
|
h.UpdateErrorResolution(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
+// ListUpstreamErrors lists independent upstream errors.
+// GET /api/v1/admin/ops/upstream-errors
+func (h *OpsHandler) ListUpstreamErrors(c *gin.Context) {
+	if h.opsService == nil {
+		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+		return
+	}
+	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	page, pageSize := response.ParsePagination(c)
+	if pageSize > 500 {
+		pageSize = 500
+	}
+	startTime, endTime, err := parseOpsTimeRange(c, "1h")
+	if err != nil {
+		response.BadRequest(c, err.Error())
+		return
+	}
+
+	filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
+	if !startTime.IsZero() {
+		filter.StartTime = &startTime
+	}
+	if !endTime.IsZero() {
+		filter.EndTime = &endTime
+	}
+
+	filter.View = parseOpsViewParam(c)
+	filter.Phase = "upstream"
+	filter.Owner = "provider"
+	filter.Source = strings.TrimSpace(c.Query("error_source"))
+	filter.Query = strings.TrimSpace(c.Query("q"))
+
+	if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
+		filter.Platform = platform
+	}
+	if v := strings.TrimSpace(c.Query("group_id")); v != "" {
+		id, err := strconv.ParseInt(v, 10, 64)
+		if err != nil || id <= 0 {
+			response.BadRequest(c, "Invalid group_id")
+			return
+		}
+		filter.GroupID = &id
+	}
+	if v := strings.TrimSpace(c.Query("account_id")); v != "" {
+		id, err := strconv.ParseInt(v, 10, 64)
+		if err != nil || id <= 0 {
+			response.BadRequest(c, "Invalid account_id")
+			return
+		}
+		filter.AccountID = &id
+	}
+
+	if v := strings.TrimSpace(c.Query("resolved")); v != "" {
+		switch strings.ToLower(v) {
+		case "1", "true", "yes":
+			b := true
+			filter.Resolved = &b
+		case "0", "false", "no":
+			b := false
+			filter.Resolved = &b
+		default:
+			response.BadRequest(c, "Invalid resolved")
+			return
+		}
+	}
+	if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
+		parts := strings.Split(statusCodesStr, ",")
+		out := make([]int, 0, len(parts))
+		for _, part := range parts {
+			p := strings.TrimSpace(part)
+			if p == "" {
+				continue
+			}
+			n, err := strconv.Atoi(p)
+			if err != nil || n < 0 {
+				response.BadRequest(c, "Invalid status_codes")
+				return
+			}
+			out = append(out, n)
+		}
+		filter.StatusCodes = out
+	}
+
+	result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
+}
+
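The `resolved` tri-state filter and the comma-separated `status_codes` list in `ListUpstreamErrors` can be exercised in isolation. The sketch below re-implements just that parsing as standalone helpers (the helper names `parseResolved` and `parseStatusCodes` are illustrative, not part of the handler):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseResolved maps "1/true/yes" and "0/false/no" to a *bool;
// an empty value means "no filter", anything else is rejected.
func parseResolved(v string) (*bool, error) {
	v = strings.TrimSpace(v)
	if v == "" {
		return nil, nil
	}
	switch strings.ToLower(v) {
	case "1", "true", "yes":
		b := true
		return &b, nil
	case "0", "false", "no":
		b := false
		return &b, nil
	default:
		return nil, fmt.Errorf("invalid resolved: %q", v)
	}
}

// parseStatusCodes splits a comma-separated list, skipping empty
// items and rejecting negative or non-numeric entries.
func parseStatusCodes(s string) ([]int, error) {
	out := make([]int, 0)
	for _, part := range strings.Split(s, ",") {
		p := strings.TrimSpace(part)
		if p == "" {
			continue
		}
		n, err := strconv.Atoi(p)
		if err != nil || n < 0 {
			return nil, fmt.Errorf("invalid status code: %q", p)
		}
		out = append(out, n)
	}
	return out, nil
}

func main() {
	r, _ := parseResolved("Yes")
	fmt.Println(*r)
	codes, _ := parseStatusCodes("429, 500,,529")
	fmt.Println(codes)
}
```

Returning `(*bool, error)` keeps the three states (filter on, filter off, no filter) distinct without a sentinel value, mirroring the `filter.Resolved *bool` field in the handler.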
+// GetUpstreamError returns upstream error detail.
+// GET /api/v1/admin/ops/upstream-errors/:id
+func (h *OpsHandler) GetUpstreamError(c *gin.Context) {
+	h.GetErrorLogByID(c)
+}
+
+// RetryUpstreamError retries upstream error using the original account_id.
+// POST /api/v1/admin/ops/upstream-errors/:id/retry
+func (h *OpsHandler) RetryUpstreamError(c *gin.Context) {
+	if h.opsService == nil {
+		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+		return
+	}
+	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	subject, ok := middleware.GetAuthSubjectFromContext(c)
+	if !ok || subject.UserID <= 0 {
+		response.Error(c, http.StatusUnauthorized, "Unauthorized")
+		return
+	}
+
+	idStr := strings.TrimSpace(c.Param("id"))
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil || id <= 0 {
+		response.BadRequest(c, "Invalid error id")
+		return
+	}
+
+	result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeUpstream, nil)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Success(c, result)
+}
+
+// ResolveUpstreamError toggles resolved status.
+// PUT /api/v1/admin/ops/upstream-errors/:id/resolve
+func (h *OpsHandler) ResolveUpstreamError(c *gin.Context) {
+	h.UpdateErrorResolution(c)
+}
+
+// ==================== Existing endpoints ====================
+
 // ListRequestDetails returns a request-level list (success + error) for drill-down.
 // GET /api/v1/admin/ops/requests
 func (h *OpsHandler) ListRequestDetails(c *gin.Context) {

@@ -242,6 +709,11 @@ func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
 type opsRetryRequest struct {
 	Mode            string `json:"mode"`
 	PinnedAccountID *int64 `json:"pinned_account_id"`
+	Force           bool   `json:"force"`
+}
+
+type opsResolveRequest struct {
+	Resolved bool `json:"resolved"`
 }
 
 // RetryErrorRequest retries a failed request using stored request_body.
@@ -278,6 +750,16 @@ func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
 		req.Mode = service.OpsRetryModeClient
 	}
+
+	// Force flag is currently a UI-level acknowledgement. Server may still enforce safety constraints.
+	_ = req.Force
+
+	// Legacy endpoint safety: only allow retrying the client request here.
+	// Upstream retries must go through the split endpoints.
+	if strings.EqualFold(strings.TrimSpace(req.Mode), service.OpsRetryModeUpstream) {
+		response.BadRequest(c, "upstream retry is not supported on this endpoint")
+		return
+	}
+
 	result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
 	if err != nil {
 		response.ErrorFrom(c, err)
@@ -287,6 +769,81 @@ func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
 	response.Success(c, result)
 }
+
+// ListRetryAttempts lists retry attempts for an error log.
+// GET /api/v1/admin/ops/errors/:id/retries
+func (h *OpsHandler) ListRetryAttempts(c *gin.Context) {
+	if h.opsService == nil {
+		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+		return
+	}
+	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	idStr := strings.TrimSpace(c.Param("id"))
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil || id <= 0 {
+		response.BadRequest(c, "Invalid error id")
+		return
+	}
+
+	limit := 50
+	if v := strings.TrimSpace(c.Query("limit")); v != "" {
+		n, err := strconv.Atoi(v)
+		if err != nil || n <= 0 {
+			response.BadRequest(c, "Invalid limit")
+			return
+		}
+		limit = n
+	}
+
+	items, err := h.opsService.ListRetryAttemptsByErrorID(c.Request.Context(), id, limit)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Success(c, items)
+}
+
+// UpdateErrorResolution allows manual resolve/unresolve.
+// PUT /api/v1/admin/ops/errors/:id/resolve
+func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
+	if h.opsService == nil {
+		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+		return
+	}
+	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	subject, ok := middleware.GetAuthSubjectFromContext(c)
+	if !ok || subject.UserID <= 0 {
+		response.Error(c, http.StatusUnauthorized, "Unauthorized")
+		return
+	}
+
+	idStr := strings.TrimSpace(c.Param("id"))
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil || id <= 0 {
+		response.BadRequest(c, "Invalid error id")
+		return
+	}
+
+	var req opsResolveRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		response.BadRequest(c, "Invalid request: "+err.Error())
+		return
+	}
+	uid := subject.UserID
+	if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid, nil); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Success(c, gin.H{"ok": true})
+}
+
 func parseOpsTimeRange(c *gin.Context, defaultRange string) (time.Time, time.Time, error) {
 	startStr := strings.TrimSpace(c.Query("start_time"))
 	endStr := strings.TrimSpace(c.Query("end_time"))

@@ -358,6 +915,10 @@ func parseOpsDuration(v string) (time.Duration, bool) {
 		return 6 * time.Hour, true
 	case "24h":
 		return 24 * time.Hour, true
+	case "7d":
+		return 7 * 24 * time.Hour, true
+	case "30d":
+		return 30 * 24 * time.Hour, true
 	default:
 		return 0, false
 	}
@@ -196,6 +196,28 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
 	response.Success(c, gin.H{"message": "Proxy deleted successfully"})
 }
+
+// BatchDelete handles batch deleting proxies
+// POST /api/v1/admin/proxies/batch-delete
+func (h *ProxyHandler) BatchDelete(c *gin.Context) {
+	type BatchDeleteRequest struct {
+		IDs []int64 `json:"ids" binding:"required,min=1"`
+	}
+
+	var req BatchDeleteRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		response.BadRequest(c, "Invalid request: "+err.Error())
+		return
+	}
+
+	result, err := h.adminService.BatchDeleteProxies(c.Request.Context(), req.IDs)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	response.Success(c, result)
+}
+
 // Test handles testing proxy connectivity
 // POST /api/v1/admin/proxies/:id/test
 func (h *ProxyHandler) Test(c *gin.Context) {
@@ -243,19 +265,17 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
 		return
 	}
 
-	page, pageSize := response.ParsePagination(c)
-
-	accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
+	accounts, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID)
 	if err != nil {
 		response.ErrorFrom(c, err)
 		return
 	}
 
-	out := make([]dto.Account, 0, len(accounts))
+	out := make([]dto.ProxyAccountSummary, 0, len(accounts))
 	for i := range accounts {
-		out = append(out, *dto.AccountFromService(&accounts[i]))
+		out = append(out, *dto.ProxyAccountSummaryFromService(&accounts[i]))
 	}
-	response.Paginated(c, out, total, page, pageSize)
+	response.Success(c, out)
 }
 
 // BatchCreateProxyItem represents a single proxy in batch create request
@@ -89,6 +89,8 @@ func GroupFromServiceShallow(g *service.Group) *Group {
 		ImagePrice4K:    g.ImagePrice4K,
 		ClaudeCodeOnly:  g.ClaudeCodeOnly,
 		FallbackGroupID: g.FallbackGroupID,
+		ModelRouting:        g.ModelRouting,
+		ModelRoutingEnabled: g.ModelRoutingEnabled,
 		CreatedAt:       g.CreatedAt,
 		UpdatedAt:       g.UpdatedAt,
 		AccountCount:    g.AccountCount,

@@ -114,7 +116,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
 	if a == nil {
 		return nil
 	}
-	return &Account{
+	out := &Account{
 		ID:    a.ID,
 		Name:  a.Name,
 		Notes: a.Notes,

@@ -125,6 +127,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
 		ProxyID:     a.ProxyID,
 		Concurrency: a.Concurrency,
 		Priority:    a.Priority,
+		RateMultiplier: a.BillingRateMultiplier(),
 		Status:       a.Status,
 		ErrorMessage: a.ErrorMessage,
 		LastUsedAt:   a.LastUsedAt,

@@ -143,6 +146,24 @@ func AccountFromServiceShallow(a *service.Account) *Account {
 		SessionWindowStatus: a.SessionWindowStatus,
 		GroupIDs:            a.GroupIDs,
 	}
+
+	// Extract the 5h-window cost control and session count control settings (effective only for Anthropic OAuth/SetupToken accounts)
+	if a.IsAnthropicOAuthOrSetupToken() {
+		if limit := a.GetWindowCostLimit(); limit > 0 {
+			out.WindowCostLimit = &limit
+		}
+		if reserve := a.GetWindowCostStickyReserve(); reserve > 0 {
+			out.WindowCostStickyReserve = &reserve
+		}
+		if maxSessions := a.GetMaxSessions(); maxSessions > 0 {
+			out.MaxSessions = &maxSessions
+		}
+		if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
+			out.SessionIdleTimeoutMin = &idleTimeout
+		}
+	}
+
+	return out
 }
 
 func AccountFromService(a *service.Account) *Account {
@@ -214,6 +235,27 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
 	return &ProxyWithAccountCount{
 		Proxy:        *ProxyFromService(&p.Proxy),
 		AccountCount: p.AccountCount,
+		LatencyMs:      p.LatencyMs,
+		LatencyStatus:  p.LatencyStatus,
+		LatencyMessage: p.LatencyMessage,
+		IPAddress:      p.IPAddress,
+		Country:        p.Country,
+		CountryCode:    p.CountryCode,
+		Region:         p.Region,
+		City:           p.City,
+	}
+}
+
+func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
+	if a == nil {
+		return nil
+	}
+	return &ProxyAccountSummary{
+		ID:       a.ID,
+		Name:     a.Name,
+		Platform: a.Platform,
+		Type:     a.Type,
+		Notes:    a.Notes,
 	}
 }

@@ -279,6 +321,7 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
 		TotalCost:      l.TotalCost,
 		ActualCost:     l.ActualCost,
 		RateMultiplier: l.RateMultiplier,
+		AccountRateMultiplier: l.AccountRateMultiplier,
 		BillingType: l.BillingType,
 		Stream:      l.Stream,
 		DurationMs:  l.DurationMs,
@@ -58,6 +58,10 @@ type Group struct {
 	ClaudeCodeOnly  bool   `json:"claude_code_only"`
 	FallbackGroupID *int64 `json:"fallback_group_id"`
 
+	// Model routing configuration (used by the anthropic platform only)
+	ModelRouting        map[string][]int64 `json:"model_routing"`
+	ModelRoutingEnabled bool               `json:"model_routing_enabled"`
+
 	CreatedAt time.Time `json:"created_at"`
 	UpdatedAt time.Time `json:"updated_at"`

@@ -76,6 +80,7 @@ type Account struct {
 	ProxyID     *int64 `json:"proxy_id"`
 	Concurrency int    `json:"concurrency"`
 	Priority    int    `json:"priority"`
+	RateMultiplier float64 `json:"rate_multiplier"`
 	Status       string     `json:"status"`
 	ErrorMessage string     `json:"error_message"`
 	LastUsedAt   *time.Time `json:"last_used_at"`

@@ -97,6 +102,16 @@ type Account struct {
 	SessionWindowEnd    *time.Time `json:"session_window_end"`
 	SessionWindowStatus string     `json:"session_window_status"`
 
+	// 5h-window cost control (effective only for Anthropic OAuth/SetupToken accounts),
+	// extracted from the extra field for frontend display and editing
+	WindowCostLimit         *float64 `json:"window_cost_limit,omitempty"`
+	WindowCostStickyReserve *float64 `json:"window_cost_sticky_reserve,omitempty"`
+
+	// Session count control (effective only for Anthropic OAuth/SetupToken accounts),
+	// extracted from the extra field for frontend display and editing
+	MaxSessions           *int `json:"max_sessions,omitempty"`
+	SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
+
 	Proxy         *Proxy         `json:"proxy,omitempty"`
 	AccountGroups []AccountGroup `json:"account_groups,omitempty"`

@@ -130,6 +145,22 @@ type Proxy struct {
 type ProxyWithAccountCount struct {
 	Proxy
 	AccountCount int64 `json:"account_count"`
+	LatencyMs      *int64 `json:"latency_ms,omitempty"`
+	LatencyStatus  string `json:"latency_status,omitempty"`
+	LatencyMessage string `json:"latency_message,omitempty"`
+	IPAddress      string `json:"ip_address,omitempty"`
+	Country        string `json:"country,omitempty"`
+	CountryCode    string `json:"country_code,omitempty"`
+	Region         string `json:"region,omitempty"`
+	City           string `json:"city,omitempty"`
+}
+
+type ProxyAccountSummary struct {
+	ID       int64   `json:"id"`
+	Name     string  `json:"name"`
+	Platform string  `json:"platform"`
+	Type     string  `json:"type"`
+	Notes    *string `json:"notes,omitempty"`
 }
 
 type RedeemCode struct {

@@ -176,6 +207,7 @@ type UsageLog struct {
 	TotalCost      float64 `json:"total_cost"`
 	ActualCost     float64 `json:"actual_cost"`
 	RateMultiplier float64 `json:"rate_multiplier"`
+	AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
 
 	BillingType int8 `json:"billing_type"`
 	Stream      bool `json:"stream"`
@@ -185,7 +185,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
 	lastFailoverStatus := 0
 
 	for {
-		selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
+		selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini does not use session limits
 		if err != nil {
 			if len(failedAccountIDs) == 0 {
 				h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)

@@ -320,7 +320,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
 
 	for {
 		// Select an account that supports this model
-		selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
+		selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
 		if err != nil {
 			if len(failedAccountIDs) == 0 {
 				h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -226,7 +226,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
 	lastFailoverStatus := 0
 
 	for {
-		selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
+		selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini does not use session limits
 		if err != nil {
 			if len(failedAccountIDs) == 0 {
 				googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
@@ -544,6 +544,11 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
 		body := w.buf.Bytes()
 		parsed := parseOpsErrorResponse(body)
+
+		// Skip logging if the error should be filtered based on settings
+		if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
+			return
+		}
 		apiKey, _ := middleware2.GetAPIKeyFromContext(c)
 
 		clientRequestID, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
@@ -832,28 +837,30 @@ func normalizeOpsErrorType(errType string, code string) string {
 
 func classifyOpsPhase(errType, message, code string) string {
 	msg := strings.ToLower(message)
+	// Standardized phases: request|auth|routing|upstream|network|internal
+	// Map billing/concurrency/response => request; scheduling => routing.
 	switch strings.TrimSpace(code) {
 	case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
-		return "billing"
+		return "request"
 	}
 
 	switch errType {
 	case "authentication_error":
 		return "auth"
 	case "billing_error", "subscription_error":
-		return "billing"
+		return "request"
 	case "rate_limit_error":
 		if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") || strings.Contains(msg, "queue") {
-			return "concurrency"
+			return "request"
 		}
 		return "upstream"
 	case "invalid_request_error":
-		return "response"
+		return "request"
 	case "upstream_error", "overloaded_error":
 		return "upstream"
 	case "api_error":
 		if strings.Contains(msg, "no available accounts") {
-			return "scheduling"
+			return "routing"
 		}
 		return "internal"
 	default:
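The hunk above collapses the legacy phase vocabulary (billing, concurrency, response, scheduling) into the standardized set. A condensed, self-contained sketch of the resulting mapping makes the fallthrough behavior easy to test in isolation (the `code`-based short-circuit from the full function is omitted here for brevity):

```go
package main

import (
	"fmt"
	"strings"
)

// classifyPhase is a reduced sketch of the standardized phase mapping:
// billing/concurrency/response-style errors collapse into "request",
// scheduling failures into "routing".
func classifyPhase(errType, message string) string {
	msg := strings.ToLower(message)
	switch errType {
	case "authentication_error":
		return "auth"
	case "billing_error", "subscription_error", "invalid_request_error":
		return "request"
	case "rate_limit_error":
		// Local concurrency/queue pressure is the caller's request load,
		// not an upstream rate limit.
		if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") || strings.Contains(msg, "queue") {
			return "request"
		}
		return "upstream"
	case "upstream_error", "overloaded_error":
		return "upstream"
	case "api_error":
		if strings.Contains(msg, "no available accounts") {
			return "routing"
		}
		return "internal"
	default:
		return "internal"
	}
}

func main() {
	fmt.Println(classifyPhase("rate_limit_error", "too many pending requests")) // request
	fmt.Println(classifyPhase("api_error", "no available accounts"))            // routing
}
```

Keeping the phase set small (`request|auth|routing|upstream|network|internal`) is what lets the owner/source classifiers below reduce to simple switches over phase.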
@@ -914,34 +921,38 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
 }
 
 func classifyOpsErrorOwner(phase string, message string) string {
+	// Standardized owners: client|provider|platform
 	switch phase {
 	case "upstream", "network":
 		return "provider"
-	case "billing", "concurrency", "auth", "response":
+	case "request", "auth":
 		return "client"
+	case "routing", "internal":
+		return "platform"
 	default:
 		if strings.Contains(strings.ToLower(message), "upstream") {
 			return "provider"
 		}
-		return "sub2api"
+		return "platform"
 	}
 }
 
 func classifyOpsErrorSource(phase string, message string) string {
+	// Standardized sources: client_request|upstream_http|gateway
 	switch phase {
 	case "upstream":
 		return "upstream_http"
 	case "network":
-		return "upstream_network"
-	case "billing":
-		return "billing"
-	case "concurrency":
-		return "concurrency"
+		return "gateway"
+	case "request", "auth":
+		return "client_request"
+	case "routing", "internal":
+		return "gateway"
 	default:
 		if strings.Contains(strings.ToLower(message), "upstream") {
 			return "upstream_http"
 		}
-		return "internal"
+		return "gateway"
 	}
 }
||||||
@@ -963,3 +974,42 @@ func truncateString(s string, max int) string {
|
|||||||
func strconvItoa(v int) string {
|
func strconvItoa(v int) string {
|
||||||
return strconv.Itoa(v)
|
return strconv.Itoa(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shouldSkipOpsErrorLog determines if an error should be skipped from logging based on settings.
|
||||||
|
// Returns true for errors that should be filtered according to OpsAdvancedSettings.
|
||||||
|
func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message, body, requestPath string) bool {
|
||||||
|
if ops == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get advanced settings to check filter configuration
|
||||||
|
settings, err := ops.GetOpsAdvancedSettings(ctx)
|
||||||
|
if err != nil || settings == nil {
|
||||||
|
// If we can't get settings, don't skip (fail open)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
msgLower := strings.ToLower(message)
|
||||||
|
bodyLower := strings.ToLower(body)
|
||||||
|
|
||||||
|
// Check if count_tokens errors should be ignored
|
||||||
|
if settings.IgnoreCountTokensErrors && strings.Contains(requestPath, "/count_tokens") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if context canceled errors should be ignored (client disconnects)
|
||||||
|
if settings.IgnoreContextCanceled {
|
||||||
|
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if "no available accounts" errors should be ignored
|
||||||
|
if settings.IgnoreNoAvailableAccounts {
|
||||||
|
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
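The filter in `shouldSkipOpsErrorLog` is a pure predicate once the settings have been fetched, so the matching rules can be sketched and tested without the service layer (the `Settings` struct below stands in for `OpsAdvancedSettings` and is illustrative only):

```go
package main

import (
	"fmt"
	"strings"
)

// Settings mirrors the three toggles consulted by the middleware filter.
type Settings struct {
	IgnoreCountTokensErrors   bool
	IgnoreContextCanceled     bool
	IgnoreNoAvailableAccounts bool
}

// shouldSkip applies the same rules as the middleware: each toggle
// suppresses one class of noisy error; matching is substring-based
// and case-insensitive on the error message.
func shouldSkip(s Settings, message, path string) bool {
	msg := strings.ToLower(message)
	if s.IgnoreCountTokensErrors && strings.Contains(path, "/count_tokens") {
		return true
	}
	if s.IgnoreContextCanceled && strings.Contains(msg, "context canceled") {
		return true
	}
	if s.IgnoreNoAvailableAccounts && strings.Contains(msg, "no available accounts") {
		return true
	}
	return false
}

func main() {
	s := Settings{IgnoreContextCanceled: true}
	fmt.Println(shouldSkip(s, "Context canceled by client", "/v1/messages")) // true
	fmt.Println(shouldSkip(s, "upstream timeout", "/v1/messages"))           // false
}
```

The fail-open choice in the real function (log when settings cannot be loaded) means a transient settings-store error can only produce extra log rows, never silently dropped ones.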
@@ -1,8 +1,14 @@
 package usagestats
 
 // AccountStats holds per-account usage statistics.
+//
+// cost: account-level cost (total_cost * account_rate_multiplier)
+// standard_cost: standard cost (total_cost, no multiplier applied)
+// user_cost: user/API-key cost (actual_cost, affected by the group multiplier)
 type AccountStats struct {
 	Requests int64   `json:"requests"`
 	Tokens   int64   `json:"tokens"`
 	Cost     float64 `json:"cost"`
+	StandardCost float64 `json:"standard_cost"`
+	UserCost     float64 `json:"user_cost"`
 }
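The three cost views documented on `AccountStats` differ only in which multiplier is applied to each usage row. Assuming a per-row record carrying `total_cost`, `account_rate_multiplier`, and `actual_cost` (field names here are illustrative, not the repository's row type), the aggregation reduces to:

```go
package main

import "fmt"

// Record is an assumed shape for one usage-log row.
type Record struct {
	TotalCost             float64 // standard cost, no multiplier
	AccountRateMultiplier float64 // account billing multiplier
	ActualCost            float64 // user-facing cost, group multiplier already applied
}

// AccountStats mirrors the three cost views from the struct above.
type AccountStats struct {
	Cost         float64 // total_cost * account_rate_multiplier
	StandardCost float64 // total_cost
	UserCost     float64 // actual_cost
}

func aggregate(rows []Record) AccountStats {
	var s AccountStats
	for _, r := range rows {
		s.StandardCost += r.TotalCost
		s.Cost += r.TotalCost * r.AccountRateMultiplier
		s.UserCost += r.ActualCost
	}
	return s
}

func main() {
	rows := []Record{
		{TotalCost: 1.00, AccountRateMultiplier: 0.8, ActualCost: 1.50},
		{TotalCost: 2.00, AccountRateMultiplier: 0.8, ActualCost: 3.00},
	}
	s := aggregate(rows)
	fmt.Printf("%.2f %.2f %.2f\n", s.Cost, s.StandardCost, s.UserCost) // 2.40 3.00 4.50
}
```

Keeping all three sums side by side is what lets the dashboards compare what an account consumed upstream (`cost`) against what users were billed (`user_cost`) without re-scanning the logs.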
@@ -154,6 +154,7 @@ type UsageStats struct {
 	TotalTokens       int64   `json:"total_tokens"`
 	TotalCost         float64 `json:"total_cost"`
 	TotalActualCost   float64 `json:"total_actual_cost"`
+	TotalAccountCost  *float64 `json:"total_account_cost,omitempty"`
 	AverageDurationMs float64 `json:"average_duration_ms"`
 }

@@ -177,25 +178,29 @@ type AccountUsageHistory struct {
 	Label    string `json:"label"`
 	Requests int64  `json:"requests"`
 	Tokens   int64  `json:"tokens"`
-	Cost       float64 `json:"cost"`
-	ActualCost float64 `json:"actual_cost"`
+	Cost       float64 `json:"cost"`        // standard cost (total_cost)
+	ActualCost float64 `json:"actual_cost"` // account-level cost (total_cost * account_rate_multiplier)
+	UserCost   float64 `json:"user_cost"`   // user-level cost (actual_cost, affected by the group multiplier)
 }
 
 // AccountUsageSummary represents summary statistics for an account
 type AccountUsageSummary struct {
 	Days           int `json:"days"`
 	ActualDaysUsed int `json:"actual_days_used"`
-	TotalCost         float64 `json:"total_cost"`
+	TotalCost         float64 `json:"total_cost"`      // account-level cost
+	TotalUserCost     float64 `json:"total_user_cost"` // user-level cost
 	TotalStandardCost float64 `json:"total_standard_cost"`
 	TotalRequests     int64   `json:"total_requests"`
 	TotalTokens       int64   `json:"total_tokens"`
-	AvgDailyCost     float64 `json:"avg_daily_cost"`
+	AvgDailyCost     float64 `json:"avg_daily_cost"` // account-level daily average
+	AvgDailyUserCost float64 `json:"avg_daily_user_cost"`
 	AvgDailyRequests float64 `json:"avg_daily_requests"`
 	AvgDailyTokens   float64 `json:"avg_daily_tokens"`
 	AvgDurationMs    float64 `json:"avg_duration_ms"`
 	Today *struct {
 		Date string  `json:"date"`
 		Cost float64 `json:"cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Tokens int64 `json:"tokens"`
|
Tokens int64 `json:"tokens"`
|
||||||
} `json:"today"`
|
} `json:"today"`
|
||||||
@@ -203,6 +208,7 @@ type AccountUsageSummary struct {
|
|||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
Label string `json:"label"`
|
Label string `json:"label"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
} `json:"highest_cost_day"`
|
} `json:"highest_cost_day"`
|
||||||
HighestRequestDay *struct {
|
HighestRequestDay *struct {
|
||||||
@@ -210,6 +216,7 @@ type AccountUsageSummary struct {
|
|||||||
Label string `json:"label"`
|
Label string `json:"label"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
} `json:"highest_request_day"`
|
} `json:"highest_request_day"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
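The three cost fields added above differ only in which multiplier is applied to the raw total. A minimal sketch of the relationship, using hypothetical multiplier values (the real ones come from account and group settings):

```go
package main

import "fmt"

// costViews mirrors the three cost perspectives in AccountStats:
// standard cost is the raw total, account cost applies the account
// rate multiplier, and user cost applies the group multiplier instead.
func costViews(totalCost, accountRate, groupRate float64) (standard, account, user float64) {
	standard = totalCost
	account = totalCost * accountRate
	user = totalCost * groupRate
	return
}

func main() {
	s, a, u := costViews(10.0, 0.5, 2.0)
	fmt.Println(s, a, u) // prints "10 5 20"
}
```

This is only an illustration of the field semantics; the repository computes these values in SQL, not in Go.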
@@ -80,6 +80,10 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
 		SetSchedulable(account.Schedulable).
 		SetAutoPauseOnExpired(account.AutoPauseOnExpired)
 
+	if account.RateMultiplier != nil {
+		builder.SetRateMultiplier(*account.RateMultiplier)
+	}
+
 	if account.ProxyID != nil {
 		builder.SetProxyID(*account.ProxyID)
 	}
@@ -291,6 +295,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
 		SetSchedulable(account.Schedulable).
 		SetAutoPauseOnExpired(account.AutoPauseOnExpired)
 
+	if account.RateMultiplier != nil {
+		builder.SetRateMultiplier(*account.RateMultiplier)
+	}
+
 	if account.ProxyID != nil {
 		builder.SetProxyID(*account.ProxyID)
 	} else {
@@ -786,6 +794,46 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
 	return nil
 }
 
+func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
+	if scope == "" {
+		return nil
+	}
+	now := time.Now().UTC()
+	payload := map[string]string{
+		"rate_limited_at":     now.Format(time.RFC3339),
+		"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
+	}
+	raw, err := json.Marshal(payload)
+	if err != nil {
+		return err
+	}
+
+	path := "{model_rate_limits," + scope + "}"
+	client := clientFromContext(ctx, r.client)
+	result, err := client.ExecContext(
+		ctx,
+		"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
+		path,
+		raw,
+		id,
+	)
+	if err != nil {
+		return err
+	}
+
+	affected, err := result.RowsAffected()
+	if err != nil {
+		return err
+	}
+	if affected == 0 {
+		return service.ErrAccountNotFound
+	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
+	}
+	return nil
+}
+
 func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
 	_, err := r.client.Account.Update().
 		Where(dbaccount.IDEQ(id)).
@@ -877,6 +925,30 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
 	return nil
 }
 
+func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64) error {
+	client := clientFromContext(ctx, r.client)
+	result, err := client.ExecContext(
+		ctx,
+		"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'model_rate_limits', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL",
+		id,
+	)
+	if err != nil {
+		return err
+	}
+
+	affected, err := result.RowsAffected()
+	if err != nil {
+		return err
+	}
+	if affected == 0 {
+		return service.ErrAccountNotFound
+	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
+	}
+	return nil
+}
+
 func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
 	builder := r.client.Account.Update().
 		Where(dbaccount.IDEQ(id)).
@@ -999,6 +1071,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
 		args = append(args, *updates.Priority)
 		idx++
 	}
+	if updates.RateMultiplier != nil {
+		setClauses = append(setClauses, "rate_multiplier = $"+itoa(idx))
+		args = append(args, *updates.RateMultiplier)
+		idx++
+	}
 	if updates.Status != nil {
 		setClauses = append(setClauses, "status = $"+itoa(idx))
 		args = append(args, *updates.Status)
@@ -1347,6 +1424,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
 		return nil
 	}
 
+	rateMultiplier := m.RateMultiplier
+
 	return &service.Account{
 		ID:   m.ID,
 		Name: m.Name,
@@ -1358,6 +1437,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
 		ProxyID:     m.ProxyID,
 		Concurrency: m.Concurrency,
 		Priority:    m.Priority,
+		RateMultiplier: &rateMultiplier,
 		Status:       m.Status,
 		ErrorMessage: derefString(m.ErrorMessage),
 		LastUsedAt:   m.LastUsedAt,
@@ -136,6 +136,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
 				group.FieldImagePrice4k,
 				group.FieldClaudeCodeOnly,
 				group.FieldFallbackGroupID,
+				group.FieldModelRoutingEnabled,
+				group.FieldModelRouting,
 			)
 		}).
 		Only(ctx)
@@ -422,6 +424,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
 		DefaultValidityDays: g.DefaultValidityDays,
 		ClaudeCodeOnly:      g.ClaudeCodeOnly,
 		FallbackGroupID:     g.FallbackGroupID,
+		ModelRouting:        g.ModelRouting,
+		ModelRoutingEnabled: g.ModelRoutingEnabled,
 		CreatedAt: g.CreatedAt,
 		UpdatedAt: g.UpdatedAt,
 	}
@@ -8,6 +8,7 @@ import (
 	"strings"
 	"time"
 
+	"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
 	"github.com/Wei-Shaw/sub2api/internal/service"
 	"github.com/lib/pq"
 )
@@ -41,21 +42,22 @@ func isPostgresDriver(db *sql.DB) bool {
 }
 
 func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
-	startUTC := start.UTC()
-	endUTC := end.UTC()
-	if !endUTC.After(startUTC) {
+	loc := timezone.Location()
+	startLocal := start.In(loc)
+	endLocal := end.In(loc)
+	if !endLocal.After(startLocal) {
 		return nil
 	}
 
-	hourStart := startUTC.Truncate(time.Hour)
-	hourEnd := endUTC.Truncate(time.Hour)
-	if endUTC.After(hourEnd) {
+	hourStart := startLocal.Truncate(time.Hour)
+	hourEnd := endLocal.Truncate(time.Hour)
+	if endLocal.After(hourEnd) {
 		hourEnd = hourEnd.Add(time.Hour)
 	}
 
-	dayStart := truncateToDayUTC(startUTC)
-	dayEnd := truncateToDayUTC(endUTC)
-	if endUTC.After(dayEnd) {
+	dayStart := truncateToDay(startLocal)
+	dayEnd := truncateToDay(endLocal)
+	if endLocal.After(dayEnd) {
 		dayEnd = dayEnd.Add(24 * time.Hour)
 	}
 
@@ -146,38 +148,41 @@ func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.C
 }
 
 func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
+	tzName := timezone.Name()
 	query := `
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
SELECT DISTINCT
-    date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
+    date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
    user_id
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
ON CONFLICT DO NOTHING
`
-	_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
+	_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
 	return err
 }
 
 func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
+	tzName := timezone.Name()
 	query := `
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
SELECT DISTINCT
-    (bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
+    (bucket_start AT TIME ZONE $3)::date AS bucket_date,
    user_id
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
ON CONFLICT DO NOTHING
`
-	_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
+	_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
 	return err
 }
 
 func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
+	tzName := timezone.Name()
 	query := `
WITH hourly AS (
    SELECT
-        date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
+        date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
        COUNT(*) AS total_requests,
        COALESCE(SUM(input_tokens), 0) AS input_tokens,
        COALESCE(SUM(output_tokens), 0) AS output_tokens,
@@ -236,15 +241,16 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
        active_users = EXCLUDED.active_users,
        computed_at = EXCLUDED.computed_at
`
-	_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
+	_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
 	return err
 }
 
 func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
+	tzName := timezone.Name()
 	query := `
WITH daily AS (
    SELECT
-        (bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
+        (bucket_start AT TIME ZONE $5)::date AS bucket_date,
        COALESCE(SUM(total_requests), 0) AS total_requests,
        COALESCE(SUM(input_tokens), 0) AS input_tokens,
        COALESCE(SUM(output_tokens), 0) AS output_tokens,
@@ -255,7 +261,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
        COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
    FROM usage_dashboard_hourly
    WHERE bucket_start >= $1 AND bucket_start < $2
-    GROUP BY (bucket_start AT TIME ZONE 'UTC')::date
+    GROUP BY (bucket_start AT TIME ZONE $5)::date
),
user_counts AS (
    SELECT bucket_date, COUNT(*) AS active_users
@@ -303,7 +309,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
        active_users = EXCLUDED.active_users,
        computed_at = EXCLUDED.computed_at
`
-	_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC(), start.UTC(), end.UTC())
+	_, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
 	return err
 }
 
@@ -376,9 +382,8 @@ func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Co
 	return err
 }
 
-func truncateToDayUTC(t time.Time) time.Time {
-	t = t.UTC()
-	return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
+func truncateToDay(t time.Time) time.Time {
+	return timezone.StartOfDay(t)
 }
 
 func truncateToMonthUTC(t time.Time) time.Time {
@@ -11,8 +11,8 @@ import (
 )
 
 const (
-	geminiTokenKeyPrefix       = "gemini:token:"
-	geminiRefreshLockKeyPrefix = "gemini:refresh_lock:"
+	oauthTokenKeyPrefix       = "oauth:token:"
+	oauthRefreshLockKeyPrefix = "oauth:refresh_lock:"
 )
 
 type geminiTokenCache struct {
@@ -24,21 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
 }
 
 func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
-	key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
+	key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
 	return c.rdb.Get(ctx, key).Result()
 }
 
 func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
-	key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
+	key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
 	return c.rdb.Set(ctx, key, token, ttl).Err()
 }
 
+func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
+	key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
+	return c.rdb.Del(ctx, key).Err()
+}
+
 func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
-	key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
+	key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
 	return c.rdb.SetNX(ctx, key, 1, ttl).Result()
 }
 
 func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
-	key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
+	key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
 	return c.rdb.Del(ctx, key).Err()
 }
@@ -0,0 +1,47 @@
+//go:build integration
+
+package repository
+
+import (
+	"errors"
+	"testing"
+	"time"
+
+	"github.com/Wei-Shaw/sub2api/internal/service"
+	"github.com/redis/go-redis/v9"
+	"github.com/stretchr/testify/require"
+	"github.com/stretchr/testify/suite"
+)
+
+type GeminiTokenCacheSuite struct {
+	IntegrationRedisSuite
+	cache service.GeminiTokenCache
+}
+
+func (s *GeminiTokenCacheSuite) SetupTest() {
+	s.IntegrationRedisSuite.SetupTest()
+	s.cache = NewGeminiTokenCache(s.rdb)
+}
+
+func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
+	cacheKey := "project-123"
+	token := "token-value"
+	require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
+
+	got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
+	require.NoError(s.T(), err)
+	require.Equal(s.T(), token, got)
+
+	require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
+
+	_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
+	require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
+}
+
+func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
+	require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
+}
+
+func TestGeminiTokenCacheSuite(t *testing.T) {
+	suite.Run(t, new(GeminiTokenCacheSuite))
+}
backend/internal/repository/gemini_token_cache_test.go (new file, 28 lines)
@@ -0,0 +1,28 @@
+//go:build unit
+
+package repository
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/redis/go-redis/v9"
+	"github.com/stretchr/testify/require"
+)
+
+func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
+	rdb := redis.NewClient(&redis.Options{
+		Addr:         "127.0.0.1:1",
+		DialTimeout:  50 * time.Millisecond,
+		ReadTimeout:  50 * time.Millisecond,
+		WriteTimeout: 50 * time.Millisecond,
+	})
+	t.Cleanup(func() {
+		_ = rdb.Close()
+	})
+
+	cache := NewGeminiTokenCache(rdb)
+	err := cache.DeleteAccessToken(context.Background(), "broken")
+	require.Error(t, err)
+}
@@ -49,7 +49,13 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
 		SetNillableImagePrice4k(groupIn.ImagePrice4K).
 		SetDefaultValidityDays(groupIn.DefaultValidityDays).
 		SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
-		SetNillableFallbackGroupID(groupIn.FallbackGroupID)
+		SetNillableFallbackGroupID(groupIn.FallbackGroupID).
+		SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
+
+	// Set the model routing configuration
+	if groupIn.ModelRouting != nil {
+		builder = builder.SetModelRouting(groupIn.ModelRouting)
+	}
 
 	created, err := builder.Save(ctx)
 	if err == nil {
@@ -101,7 +107,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
 		SetNillableImagePrice2k(groupIn.ImagePrice2K).
 		SetNillableImagePrice4k(groupIn.ImagePrice4K).
 		SetDefaultValidityDays(groupIn.DefaultValidityDays).
-		SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
+		SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
+		SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
 
 	// Handle FallbackGroupID: clear when nil, otherwise set
 	if groupIn.FallbackGroupID != nil {
@@ -110,6 +117,13 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
 		builder = builder.ClearFallbackGroupID()
 	}
 
+	// Handle ModelRouting: clear when nil, otherwise set
+	if groupIn.ModelRouting != nil {
+		builder = builder.SetModelRouting(groupIn.ModelRouting)
+	} else {
+		builder = builder.ClearModelRouting()
+	}
+
 	updated, err := builder.Save(ctx)
 	if err != nil {
 		return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|||||||
@@ -55,7 +55,6 @@ INSERT INTO ops_error_logs (
|
|||||||
upstream_error_message,
|
upstream_error_message,
|
||||||
upstream_error_detail,
|
upstream_error_detail,
|
||||||
upstream_errors,
|
upstream_errors,
|
||||||
duration_ms,
|
|
||||||
time_to_first_token_ms,
|
time_to_first_token_ms,
|
||||||
request_body,
|
request_body,
|
||||||
request_body_truncated,
|
request_body_truncated,
|
||||||
@@ -65,7 +64,7 @@ INSERT INTO ops_error_logs (
|
|||||||
retry_count,
|
retry_count,
|
||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35
|
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
|
||||||
) RETURNING id`
|
) RETURNING id`
|
||||||
|
|
||||||
var id int64
|
var id int64
|
||||||
@@ -98,7 +97,6 @@ INSERT INTO ops_error_logs (
|
|||||||
opsNullString(input.UpstreamErrorMessage),
|
opsNullString(input.UpstreamErrorMessage),
|
||||||
opsNullString(input.UpstreamErrorDetail),
|
opsNullString(input.UpstreamErrorDetail),
|
||||||
opsNullString(input.UpstreamErrorsJSON),
|
opsNullString(input.UpstreamErrorsJSON),
|
||||||
opsNullInt(input.DurationMs),
|
|
||||||
opsNullInt64(input.TimeToFirstTokenMs),
|
opsNullInt64(input.TimeToFirstTokenMs),
|
||||||
opsNullString(input.RequestBodyJSON),
|
opsNullString(input.RequestBodyJSON),
|
||||||
input.RequestBodyTruncated,
|
input.RequestBodyTruncated,
|
||||||
@@ -135,7 +133,7 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
|
|||||||
}
|
}
|
||||||
|
|
||||||
where, args := buildOpsErrorLogsWhere(filter)
|
where, args := buildOpsErrorLogsWhere(filter)
|
||||||
countSQL := "SELECT COUNT(*) FROM ops_error_logs " + where
|
countSQL := "SELECT COUNT(*) FROM ops_error_logs e " + where
|
||||||
|
|
||||||
var total int
|
var total int
|
||||||
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
|
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
|
||||||
@@ -146,28 +144,43 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
|
|||||||
argsWithLimit := append(args, pageSize, offset)
|
argsWithLimit := append(args, pageSize, offset)
|
||||||
selectSQL := `
|
selectSQL := `
|
||||||
SELECT
|
SELECT
|
||||||
id,
|
e.id,
|
||||||
created_at,
|
e.created_at,
|
||||||
error_phase,
|
e.error_phase,
|
||||||
error_type,
|
e.error_type,
|
||||||
severity,
|
COALESCE(e.error_owner, ''),
|
||||||
COALESCE(upstream_status_code, status_code, 0),
|
COALESCE(e.error_source, ''),
|
||||||
COALESCE(platform, ''),
|
e.severity,
|
||||||
COALESCE(model, ''),
|
COALESCE(e.upstream_status_code, e.status_code, 0),
|
||||||
duration_ms,
|
COALESCE(e.platform, ''),
|
||||||
COALESCE(client_request_id, ''),
|
COALESCE(e.model, ''),
|
||||||
COALESCE(request_id, ''),
|
COALESCE(e.is_retryable, false),
|
||||||
COALESCE(error_message, ''),
|
COALESCE(e.retry_count, 0),
|
||||||
user_id,
|
COALESCE(e.resolved, false),
|
||||||
api_key_id,
|
e.resolved_at,
|
||||||
account_id,
|
e.resolved_by_user_id,
|
||||||
group_id,
|
COALESCE(u2.email, ''),
|
||||||
CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
|
e.resolved_retry_id,
|
||||||
COALESCE(request_path, ''),
|
COALESCE(e.client_request_id, ''),
|
||||||
stream
|
COALESCE(e.request_id, ''),
|
||||||
FROM ops_error_logs
|
COALESCE(e.error_message, ''),
|
||||||
|
e.user_id,
|
||||||
|
COALESCE(u.email, ''),
|
||||||
|
e.api_key_id,
|
||||||
|
e.account_id,
|
||||||
|
COALESCE(a.name, ''),
|
||||||
|
e.group_id,
|
||||||
|
COALESCE(g.name, ''),
|
||||||
|
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
|
||||||
|
COALESCE(e.request_path, ''),
|
||||||
|
e.stream
|
||||||
|
FROM ops_error_logs e
|
||||||
|
LEFT JOIN accounts a ON e.account_id = a.id
|
||||||
|
LEFT JOIN groups g ON e.group_id = g.id
|
||||||
|
LEFT JOIN users u ON e.user_id = u.id
|
||||||
|
LEFT JOIN users u2 ON e.resolved_by_user_id = u2.id
|
||||||
` + where + `
|
` + where + `
|
||||||
ORDER BY created_at DESC
|
ORDER BY e.created_at DESC
|
||||||
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
|
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
|
||||||
|
|
||||||
rows, err := r.db.QueryContext(ctx, selectSQL, argsWithLimit...)
|
rows, err := r.db.QueryContext(ctx, selectSQL, argsWithLimit...)
|
||||||
@@ -179,39 +192,65 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
     out := make([]*service.OpsErrorLog, 0, pageSize)
     for rows.Next() {
         var item service.OpsErrorLog
-        var latency sql.NullInt64
         var statusCode sql.NullInt64
         var clientIP sql.NullString
         var userID sql.NullInt64
         var apiKeyID sql.NullInt64
         var accountID sql.NullInt64
+        var accountName string
         var groupID sql.NullInt64
+        var groupName string
+        var userEmail string
+        var resolvedAt sql.NullTime
+        var resolvedBy sql.NullInt64
+        var resolvedByName string
+        var resolvedRetryID sql.NullInt64
         if err := rows.Scan(
             &item.ID,
             &item.CreatedAt,
             &item.Phase,
             &item.Type,
+            &item.Owner,
+            &item.Source,
             &item.Severity,
             &statusCode,
             &item.Platform,
             &item.Model,
-            &latency,
+            &item.IsRetryable,
+            &item.RetryCount,
+            &item.Resolved,
+            &resolvedAt,
+            &resolvedBy,
+            &resolvedByName,
+            &resolvedRetryID,
             &item.ClientRequestID,
             &item.RequestID,
             &item.Message,
             &userID,
+            &userEmail,
             &apiKeyID,
             &accountID,
+            &accountName,
             &groupID,
+            &groupName,
             &clientIP,
             &item.RequestPath,
             &item.Stream,
         ); err != nil {
             return nil, err
         }
-        if latency.Valid {
-            v := int(latency.Int64)
-            item.LatencyMs = &v
+        if resolvedAt.Valid {
+            t := resolvedAt.Time
+            item.ResolvedAt = &t
+        }
+        if resolvedBy.Valid {
+            v := resolvedBy.Int64
+            item.ResolvedByUserID = &v
+        }
+        item.ResolvedByUserName = resolvedByName
+        if resolvedRetryID.Valid {
+            v := resolvedRetryID.Int64
+            item.ResolvedRetryID = &v
         }
         item.StatusCode = int(statusCode.Int64)
         if clientIP.Valid {
@@ -222,6 +261,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
             v := userID.Int64
             item.UserID = &v
         }
+        item.UserEmail = userEmail
         if apiKeyID.Valid {
             v := apiKeyID.Int64
             item.APIKeyID = &v
@@ -230,10 +270,12 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
             v := accountID.Int64
             item.AccountID = &v
         }
+        item.AccountName = accountName
         if groupID.Valid {
             v := groupID.Int64
             item.GroupID = &v
         }
+        item.GroupName = groupName
         out = append(out, &item)
     }
     if err := rows.Err(); err != nil {
@@ -258,49 +300,64 @@ func (r *opsRepository) GetErrorLogByID(ctx context.Context, id int64) (*service
 
     q := `
 SELECT
-  id,
-  created_at,
-  error_phase,
-  error_type,
-  severity,
-  COALESCE(upstream_status_code, status_code, 0),
-  COALESCE(platform, ''),
-  COALESCE(model, ''),
-  duration_ms,
-  COALESCE(client_request_id, ''),
-  COALESCE(request_id, ''),
-  COALESCE(error_message, ''),
-  COALESCE(error_body, ''),
-  upstream_status_code,
-  COALESCE(upstream_error_message, ''),
-  COALESCE(upstream_error_detail, ''),
-  COALESCE(upstream_errors::text, ''),
-  is_business_limited,
-  user_id,
-  api_key_id,
-  account_id,
-  group_id,
-  CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
-  COALESCE(request_path, ''),
-  stream,
-  COALESCE(user_agent, ''),
-  auth_latency_ms,
-  routing_latency_ms,
-  upstream_latency_ms,
-  response_latency_ms,
-  time_to_first_token_ms,
-  COALESCE(request_body::text, ''),
-  request_body_truncated,
-  request_body_bytes,
-  COALESCE(request_headers::text, '')
-FROM ops_error_logs
-WHERE id = $1
+  e.id,
+  e.created_at,
+  e.error_phase,
+  e.error_type,
+  COALESCE(e.error_owner, ''),
+  COALESCE(e.error_source, ''),
+  e.severity,
+  COALESCE(e.upstream_status_code, e.status_code, 0),
+  COALESCE(e.platform, ''),
+  COALESCE(e.model, ''),
+  COALESCE(e.is_retryable, false),
+  COALESCE(e.retry_count, 0),
+  COALESCE(e.resolved, false),
+  e.resolved_at,
+  e.resolved_by_user_id,
+  e.resolved_retry_id,
+  COALESCE(e.client_request_id, ''),
+  COALESCE(e.request_id, ''),
+  COALESCE(e.error_message, ''),
+  COALESCE(e.error_body, ''),
+  e.upstream_status_code,
+  COALESCE(e.upstream_error_message, ''),
+  COALESCE(e.upstream_error_detail, ''),
+  COALESCE(e.upstream_errors::text, ''),
+  e.is_business_limited,
+  e.user_id,
+  COALESCE(u.email, ''),
+  e.api_key_id,
+  e.account_id,
+  COALESCE(a.name, ''),
+  e.group_id,
+  COALESCE(g.name, ''),
+  CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
+  COALESCE(e.request_path, ''),
+  e.stream,
+  COALESCE(e.user_agent, ''),
+  e.auth_latency_ms,
+  e.routing_latency_ms,
+  e.upstream_latency_ms,
+  e.response_latency_ms,
+  e.time_to_first_token_ms,
+  COALESCE(e.request_body::text, ''),
+  e.request_body_truncated,
+  e.request_body_bytes,
+  COALESCE(e.request_headers::text, '')
+FROM ops_error_logs e
+LEFT JOIN users u ON e.user_id = u.id
+LEFT JOIN accounts a ON e.account_id = a.id
+LEFT JOIN groups g ON e.group_id = g.id
+WHERE e.id = $1
 LIMIT 1`
 
     var out service.OpsErrorLogDetail
-    var latency sql.NullInt64
     var statusCode sql.NullInt64
     var upstreamStatusCode sql.NullInt64
+    var resolvedAt sql.NullTime
+    var resolvedBy sql.NullInt64
+    var resolvedRetryID sql.NullInt64
     var clientIP sql.NullString
     var userID sql.NullInt64
     var apiKeyID sql.NullInt64
@@ -318,11 +375,18 @@ LIMIT 1`
         &out.CreatedAt,
         &out.Phase,
         &out.Type,
+        &out.Owner,
+        &out.Source,
         &out.Severity,
         &statusCode,
         &out.Platform,
         &out.Model,
-        &latency,
+        &out.IsRetryable,
+        &out.RetryCount,
+        &out.Resolved,
+        &resolvedAt,
+        &resolvedBy,
+        &resolvedRetryID,
         &out.ClientRequestID,
         &out.RequestID,
         &out.Message,
@@ -333,9 +397,12 @@ LIMIT 1`
         &out.UpstreamErrors,
         &out.IsBusinessLimited,
         &userID,
+        &out.UserEmail,
         &apiKeyID,
         &accountID,
+        &out.AccountName,
         &groupID,
+        &out.GroupName,
         &clientIP,
         &out.RequestPath,
         &out.Stream,
@@ -355,9 +422,17 @@ LIMIT 1`
     }
 
     out.StatusCode = int(statusCode.Int64)
-    if latency.Valid {
-        v := int(latency.Int64)
-        out.LatencyMs = &v
+    if resolvedAt.Valid {
+        t := resolvedAt.Time
+        out.ResolvedAt = &t
+    }
+    if resolvedBy.Valid {
+        v := resolvedBy.Int64
+        out.ResolvedByUserID = &v
+    }
+    if resolvedRetryID.Valid {
+        v := resolvedRetryID.Int64
+        out.ResolvedRetryID = &v
     }
     if clientIP.Valid {
         s := clientIP.String
@@ -487,9 +562,15 @@ SET
   status = $2,
   finished_at = $3,
   duration_ms = $4,
-  result_request_id = $5,
-  result_error_id = $6,
-  error_message = $7
+  success = $5,
+  http_status_code = $6,
+  upstream_request_id = $7,
+  used_account_id = $8,
+  response_preview = $9,
+  response_truncated = $10,
+  result_request_id = $11,
+  result_error_id = $12,
+  error_message = $13
 WHERE id = $1`
 
     _, err := r.db.ExecContext(
@@ -499,8 +580,14 @@ WHERE id = $1`
         strings.TrimSpace(input.Status),
         nullTime(input.FinishedAt),
         input.DurationMs,
+        nullBool(input.Success),
+        nullInt(input.HTTPStatusCode),
+        opsNullString(input.UpstreamRequestID),
+        nullInt64(input.UsedAccountID),
+        opsNullString(input.ResponsePreview),
+        nullBool(input.ResponseTruncated),
         opsNullString(input.ResultRequestID),
-        opsNullInt64(input.ResultErrorID),
+        nullInt64(input.ResultErrorID),
         opsNullString(input.ErrorMessage),
     )
     return err
@@ -526,6 +613,12 @@ SELECT
   started_at,
   finished_at,
   duration_ms,
+  success,
+  http_status_code,
+  upstream_request_id,
+  used_account_id,
+  response_preview,
+  response_truncated,
   result_request_id,
   result_error_id,
   error_message
@@ -540,6 +633,12 @@ LIMIT 1`
     var startedAt sql.NullTime
     var finishedAt sql.NullTime
     var durationMs sql.NullInt64
+    var success sql.NullBool
+    var httpStatusCode sql.NullInt64
+    var upstreamRequestID sql.NullString
+    var usedAccountID sql.NullInt64
+    var responsePreview sql.NullString
+    var responseTruncated sql.NullBool
     var resultRequestID sql.NullString
     var resultErrorID sql.NullInt64
     var errorMessage sql.NullString
@@ -555,6 +654,12 @@ LIMIT 1`
        &startedAt,
        &finishedAt,
        &durationMs,
+       &success,
+       &httpStatusCode,
+       &upstreamRequestID,
+       &usedAccountID,
+       &responsePreview,
+       &responseTruncated,
        &resultRequestID,
        &resultErrorID,
        &errorMessage,
@@ -579,6 +684,30 @@ LIMIT 1`
         v := durationMs.Int64
         out.DurationMs = &v
     }
+    if success.Valid {
+        v := success.Bool
+        out.Success = &v
+    }
+    if httpStatusCode.Valid {
+        v := int(httpStatusCode.Int64)
+        out.HTTPStatusCode = &v
+    }
+    if upstreamRequestID.Valid {
+        s := upstreamRequestID.String
+        out.UpstreamRequestID = &s
+    }
+    if usedAccountID.Valid {
+        v := usedAccountID.Int64
+        out.UsedAccountID = &v
+    }
+    if responsePreview.Valid {
+        s := responsePreview.String
+        out.ResponsePreview = &s
+    }
+    if responseTruncated.Valid {
+        v := responseTruncated.Bool
+        out.ResponseTruncated = &v
+    }
     if resultRequestID.Valid {
         s := resultRequestID.String
         out.ResultRequestID = &s
@@ -602,30 +731,234 @@ func nullTime(t time.Time) sql.NullTime {
     return sql.NullTime{Time: t, Valid: true}
 }
 
+func nullBool(v *bool) sql.NullBool {
+    if v == nil {
+        return sql.NullBool{}
+    }
+    return sql.NullBool{Bool: *v, Valid: true}
+}
+
+func (r *opsRepository) ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*service.OpsRetryAttempt, error) {
+    if r == nil || r.db == nil {
+        return nil, fmt.Errorf("nil ops repository")
+    }
+    if sourceErrorID <= 0 {
+        return nil, fmt.Errorf("invalid source_error_id")
+    }
+    if limit <= 0 {
+        limit = 50
+    }
+    if limit > 200 {
+        limit = 200
+    }
+
+    q := `
+SELECT
+  r.id,
+  r.created_at,
+  COALESCE(r.requested_by_user_id, 0),
+  r.source_error_id,
+  COALESCE(r.mode, ''),
+  r.pinned_account_id,
+  COALESCE(pa.name, ''),
+  COALESCE(r.status, ''),
+  r.started_at,
+  r.finished_at,
+  r.duration_ms,
+  r.success,
+  r.http_status_code,
+  r.upstream_request_id,
+  r.used_account_id,
+  COALESCE(ua.name, ''),
+  r.response_preview,
+  r.response_truncated,
+  r.result_request_id,
+  r.result_error_id,
+  r.error_message
+FROM ops_retry_attempts r
+LEFT JOIN accounts pa ON r.pinned_account_id = pa.id
+LEFT JOIN accounts ua ON r.used_account_id = ua.id
+WHERE r.source_error_id = $1
+ORDER BY r.created_at DESC
+LIMIT $2`
+
+    rows, err := r.db.QueryContext(ctx, q, sourceErrorID, limit)
+    if err != nil {
+        return nil, err
+    }
+    defer func() { _ = rows.Close() }()
+
+    out := make([]*service.OpsRetryAttempt, 0, 16)
+    for rows.Next() {
+        var item service.OpsRetryAttempt
+        var pinnedAccountID sql.NullInt64
+        var pinnedAccountName string
+        var requestedBy sql.NullInt64
+        var startedAt sql.NullTime
+        var finishedAt sql.NullTime
+        var durationMs sql.NullInt64
+        var success sql.NullBool
+        var httpStatusCode sql.NullInt64
+        var upstreamRequestID sql.NullString
+        var usedAccountID sql.NullInt64
+        var usedAccountName string
+        var responsePreview sql.NullString
+        var responseTruncated sql.NullBool
+        var resultRequestID sql.NullString
+        var resultErrorID sql.NullInt64
+        var errorMessage sql.NullString
+
+        if err := rows.Scan(
+            &item.ID,
+            &item.CreatedAt,
+            &requestedBy,
+            &item.SourceErrorID,
+            &item.Mode,
+            &pinnedAccountID,
+            &pinnedAccountName,
+            &item.Status,
+            &startedAt,
+            &finishedAt,
+            &durationMs,
+            &success,
+            &httpStatusCode,
+            &upstreamRequestID,
+            &usedAccountID,
+            &usedAccountName,
+            &responsePreview,
+            &responseTruncated,
+            &resultRequestID,
+            &resultErrorID,
+            &errorMessage,
+        ); err != nil {
+            return nil, err
+        }
+
+        item.RequestedByUserID = requestedBy.Int64
+        if pinnedAccountID.Valid {
+            v := pinnedAccountID.Int64
+            item.PinnedAccountID = &v
+        }
+        item.PinnedAccountName = pinnedAccountName
+        if startedAt.Valid {
+            t := startedAt.Time
+            item.StartedAt = &t
+        }
+        if finishedAt.Valid {
+            t := finishedAt.Time
+            item.FinishedAt = &t
+        }
+        if durationMs.Valid {
+            v := durationMs.Int64
+            item.DurationMs = &v
+        }
+        if success.Valid {
+            v := success.Bool
+            item.Success = &v
+        }
+        if httpStatusCode.Valid {
+            v := int(httpStatusCode.Int64)
+            item.HTTPStatusCode = &v
+        }
+        if upstreamRequestID.Valid {
+            item.UpstreamRequestID = &upstreamRequestID.String
+        }
+        if usedAccountID.Valid {
+            v := usedAccountID.Int64
+            item.UsedAccountID = &v
+        }
+        item.UsedAccountName = usedAccountName
+        if responsePreview.Valid {
+            item.ResponsePreview = &responsePreview.String
+        }
+        if responseTruncated.Valid {
+            v := responseTruncated.Bool
+            item.ResponseTruncated = &v
+        }
+        if resultRequestID.Valid {
+            item.ResultRequestID = &resultRequestID.String
+        }
+        if resultErrorID.Valid {
+            v := resultErrorID.Int64
+            item.ResultErrorID = &v
+        }
+        if errorMessage.Valid {
+            item.ErrorMessage = &errorMessage.String
+        }
+        out = append(out, &item)
+    }
+    if err := rows.Err(); err != nil {
+        return nil, err
+    }
+    return out, nil
+}
+
+func (r *opsRepository) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error {
+    if r == nil || r.db == nil {
+        return fmt.Errorf("nil ops repository")
+    }
+    if errorID <= 0 {
+        return fmt.Errorf("invalid error id")
+    }
+
+    q := `
+UPDATE ops_error_logs
+SET
+  resolved = $2,
+  resolved_at = $3,
+  resolved_by_user_id = $4,
+  resolved_retry_id = $5
+WHERE id = $1`
+
+    at := sql.NullTime{}
+    if resolvedAt != nil && !resolvedAt.IsZero() {
+        at = sql.NullTime{Time: resolvedAt.UTC(), Valid: true}
+    } else if resolved {
+        now := time.Now().UTC()
+        at = sql.NullTime{Time: now, Valid: true}
+    }
+
+    _, err := r.db.ExecContext(
+        ctx,
+        q,
+        errorID,
+        resolved,
+        at,
+        nullInt64(resolvedByUserID),
+        nullInt64(resolvedRetryID),
+    )
+    return err
+}
+
 func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
-    clauses := make([]string, 0, 8)
-    args := make([]any, 0, 8)
+    clauses := make([]string, 0, 12)
+    args := make([]any, 0, 12)
     clauses = append(clauses, "1=1")
 
     phaseFilter := ""
     if filter != nil {
         phaseFilter = strings.TrimSpace(strings.ToLower(filter.Phase))
     }
-    // ops_error_logs primarily stores client-visible error requests (status>=400),
+    // ops_error_logs stores client-visible error requests (status>=400),
     // but we also persist "recovered" upstream errors (status<400) for upstream health visibility.
-    // By default, keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
+    // If Resolved is not specified, do not filter by resolved state (backward-compatible).
+    resolvedFilter := (*bool)(nil)
+    if filter != nil {
+        resolvedFilter = filter.Resolved
+    }
+    // Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
     if phaseFilter != "upstream" {
         clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
     }
 
     if filter.StartTime != nil && !filter.StartTime.IsZero() {
         args = append(args, filter.StartTime.UTC())
-        clauses = append(clauses, "created_at >= $"+itoa(len(args)))
+        clauses = append(clauses, "e.created_at >= $"+itoa(len(args)))
     }
     if filter.EndTime != nil && !filter.EndTime.IsZero() {
         args = append(args, filter.EndTime.UTC())
         // Keep time-window semantics consistent with other ops queries: [start, end)
-        clauses = append(clauses, "created_at < $"+itoa(len(args)))
+        clauses = append(clauses, "e.created_at < $"+itoa(len(args)))
     }
     if p := strings.TrimSpace(filter.Platform); p != "" {
         args = append(args, p)
@@ -643,10 +976,59 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
         args = append(args, phase)
         clauses = append(clauses, "error_phase = $"+itoa(len(args)))
     }
+    if filter != nil {
+        if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
+            args = append(args, owner)
+            clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
+        }
+        if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
+            args = append(args, source)
+            clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
+        }
+    }
+    if resolvedFilter != nil {
+        args = append(args, *resolvedFilter)
+        clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
+    }
+
+    // View filter: errors vs excluded vs all.
+    // Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
+    view := ""
+    if filter != nil {
+        view = strings.ToLower(strings.TrimSpace(filter.View))
+    }
+    switch view {
+    case "", "errors":
+        clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
+        clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
+    case "excluded":
+        clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))")
+    case "all":
+        // no-op
+    default:
+        // treat unknown as default 'errors'
+        clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
+        clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
+    }
     if len(filter.StatusCodes) > 0 {
         args = append(args, pq.Array(filter.StatusCodes))
         clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
+    } else if filter.StatusCodesOther {
+        // "Other" means: status codes not in the common list.
+        known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
+        args = append(args, pq.Array(known))
+        clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
     }
+    // Exact correlation keys (preferred for request↔upstream linkage).
+    if rid := strings.TrimSpace(filter.RequestID); rid != "" {
+        args = append(args, rid)
+        clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
+    }
+    if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
+        args = append(args, crid)
+        clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
+    }
 
     if q := strings.TrimSpace(filter.Query); q != "" {
         like := "%" + q + "%"
         args = append(args, like)
@@ -654,6 +1036,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
         clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
     }
+
+    if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
+        like := "%" + userQuery + "%"
+        args = append(args, like)
+        n := itoa(len(args))
+        clauses = append(clauses, "u.email ILIKE $"+n)
+    }
 
     return "WHERE " + strings.Join(clauses, " AND "), args
 }
 
@@ -354,7 +354,7 @@ SELECT
|
|||||||
created_at
|
created_at
|
||||||
FROM ops_alert_events
|
FROM ops_alert_events
|
||||||
` + where + `
|
` + where + `
|
||||||
ORDER BY fired_at DESC
|
ORDER BY fired_at DESC, id DESC
|
||||||
LIMIT ` + limitArg
|
LIMIT ` + limitArg
|
||||||
|
|
||||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||||
@@ -413,6 +413,43 @@ LIMIT ` + limitArg
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) GetAlertEventByID(ctx context.Context, eventID int64) (*service.OpsAlertEvent, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if eventID <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid event id")
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
COALESCE(rule_id, 0),
|
||||||
|
COALESCE(severity, ''),
|
||||||
|
COALESCE(status, ''),
|
||||||
|
COALESCE(title, ''),
|
||||||
|
COALESCE(description, ''),
|
||||||
|
metric_value,
|
||||||
|
threshold_value,
|
||||||
|
dimensions,
|
||||||
|
fired_at,
|
||||||
|
resolved_at,
|
||||||
|
email_sent,
|
||||||
|
created_at
|
||||||
|
FROM ops_alert_events
|
||||||
|
WHERE id = $1`
|
||||||
|
|
||||||
|
row := r.db.QueryRowContext(ctx, q, eventID)
|
||||||
|
ev, err := scanOpsAlertEvent(row)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ev, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
|
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
|
||||||
if r == nil || r.db == nil {
|
if r == nil || r.db == nil {
|
||||||
return nil, fmt.Errorf("nil ops repository")
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
@@ -591,6 +628,121 @@ type opsAlertEventRow interface {
|
|||||||
Scan(dest ...any) error
|
Scan(dest ...any) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) CreateAlertSilence(ctx context.Context, input *service.OpsAlertSilence) (*service.OpsAlertSilence, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if input == nil {
|
||||||
|
return nil, fmt.Errorf("nil input")
|
||||||
|
}
|
||||||
|
if input.RuleID <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid rule_id")
|
||||||
|
}
|
||||||
|
platform := strings.TrimSpace(input.Platform)
|
||||||
|
if platform == "" {
|
||||||
|
return nil, fmt.Errorf("invalid platform")
|
||||||
|
}
|
||||||
|
if input.Until.IsZero() {
|
||||||
|
return nil, fmt.Errorf("invalid until")
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
INSERT INTO ops_alert_silences (
|
||||||
|
rule_id,
|
||||||
|
platform,
|
||||||
|
group_id,
|
||||||
|
region,
|
||||||
|
until,
|
||||||
|
reason,
|
||||||
|
created_by,
|
||||||
|
created_at
|
||||||
|
) VALUES (
|
||||||
|
$1,$2,$3,$4,$5,$6,$7,NOW()
|
||||||
|
)
|
||||||
|
RETURNING id, rule_id, platform, group_id, region, until, COALESCE(reason,''), created_by, created_at`
|
||||||
|
|
||||||
|
row := r.db.QueryRowContext(
|
||||||
|
ctx,
|
||||||
|
q,
|
||||||
|
input.RuleID,
|
||||||
|
platform,
|
||||||
|
opsNullInt64(input.GroupID),
|
||||||
|
opsNullString(input.Region),
|
||||||
|
input.Until,
|
||||||
|
opsNullString(input.Reason),
|
||||||
|
opsNullInt64(input.CreatedBy),
|
||||||
|
)
|
||||||
|
|
||||||
|
var out service.OpsAlertSilence
|
||||||
|
var groupID sql.NullInt64
|
||||||
|
var region sql.NullString
|
||||||
|
var createdBy sql.NullInt64
|
||||||
|
if err := row.Scan(
|
||||||
|
&out.ID,
|
||||||
|
&out.RuleID,
|
||||||
|
&out.Platform,
|
||||||
|
&groupID,
|
||||||
|
®ion,
|
||||||
|
&out.Until,
|
||||||
|
&out.Reason,
|
||||||
|
&createdBy,
|
||||||
|
&out.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if groupID.Valid {
|
||||||
|
v := groupID.Int64
|
||||||
|
out.GroupID = &v
|
||||||
|
}
|
||||||
|
if region.Valid {
|
||||||
|
v := strings.TrimSpace(region.String)
|
||||||
|
if v != "" {
|
||||||
|
out.Region = &v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if createdBy.Valid {
|
||||||
|
v := createdBy.Int64
|
||||||
|
out.CreatedBy = &v
|
||||||
|
}
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return false, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if ruleID <= 0 {
|
||||||
|
return false, fmt.Errorf("invalid rule id")
|
||||||
|
}
|
||||||
|
platform = strings.TrimSpace(platform)
|
||||||
|
if platform == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if now.IsZero() {
|
||||||
|
now = time.Now().UTC()
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
SELECT 1
|
||||||
|
FROM ops_alert_silences
|
||||||
|
WHERE rule_id = $1
|
||||||
|
AND platform = $2
|
||||||
|
AND (group_id IS NOT DISTINCT FROM $3)
|
||||||
|
AND (region IS NOT DISTINCT FROM $4)
|
||||||
|
AND until > $5
|
||||||
|
LIMIT 1`
|
||||||
|
|
||||||
|
var dummy int
|
||||||
|
err := r.db.QueryRowContext(ctx, q, ruleID, platform, opsNullInt64(groupID), opsNullString(region), now).Scan(&dummy)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
 func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
 	var ev service.OpsAlertEvent
 	var metricValue sql.NullFloat64
@@ -652,6 +804,10 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
 		args = append(args, severity)
 		clauses = append(clauses, "severity = $"+itoa(len(args)))
 	}
+	if filter.EmailSent != nil {
+		args = append(args, *filter.EmailSent)
+		clauses = append(clauses, "email_sent = $"+itoa(len(args)))
+	}
 	if filter.StartTime != nil && !filter.StartTime.IsZero() {
 		args = append(args, *filter.StartTime)
 		clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
@@ -661,6 +817,14 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
 		clauses = append(clauses, "fired_at < $"+itoa(len(args)))
 	}
+
+	// Cursor pagination (descending by fired_at, then id)
+	if filter.BeforeFiredAt != nil && !filter.BeforeFiredAt.IsZero() && filter.BeforeID != nil && *filter.BeforeID > 0 {
+		args = append(args, *filter.BeforeFiredAt)
+		tsArg := "$" + itoa(len(args))
+		args = append(args, *filter.BeforeID)
+		idArg := "$" + itoa(len(args))
+		clauses = append(clauses, fmt.Sprintf("(fired_at < %s OR (fired_at = %s AND id < %s))", tsArg, tsArg, idArg))
+	}
 	// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
 	if platform := strings.TrimSpace(filter.Platform); platform != "" {
 		args = append(args, platform)
@@ -296,9 +296,10 @@ INSERT INTO ops_job_heartbeats (
 	last_error_at,
 	last_error,
 	last_duration_ms,
+	last_result,
 	updated_at
 ) VALUES (
-	$1,$2,$3,$4,$5,$6,NOW()
+	$1,$2,$3,$4,$5,$6,$7,NOW()
 )
 ON CONFLICT (job_name) DO UPDATE SET
 	last_run_at = COALESCE(EXCLUDED.last_run_at, ops_job_heartbeats.last_run_at),
@@ -312,6 +313,10 @@ ON CONFLICT (job_name) DO UPDATE SET
 		ELSE COALESCE(EXCLUDED.last_error, ops_job_heartbeats.last_error)
 	END,
 	last_duration_ms = COALESCE(EXCLUDED.last_duration_ms, ops_job_heartbeats.last_duration_ms),
+	last_result = CASE
+		WHEN EXCLUDED.last_success_at IS NOT NULL THEN COALESCE(EXCLUDED.last_result, ops_job_heartbeats.last_result)
+		ELSE ops_job_heartbeats.last_result
+	END,
 	updated_at = NOW()`

 	_, err := r.db.ExecContext(
@@ -323,6 +328,7 @@ ON CONFLICT (job_name) DO UPDATE SET
 		opsNullTime(input.LastErrorAt),
 		opsNullString(input.LastError),
 		opsNullInt(input.LastDurationMs),
+		opsNullString(input.LastResult),
 	)
 	return err
 }
@@ -340,6 +346,7 @@ SELECT
 	last_error_at,
 	last_error,
 	last_duration_ms,
+	last_result,
 	updated_at
 FROM ops_job_heartbeats
 ORDER BY job_name ASC`
@@ -359,6 +366,8 @@ ORDER BY job_name ASC`
 		var lastError sql.NullString
 		var lastDuration sql.NullInt64
+
+		var lastResult sql.NullString
+
 		if err := rows.Scan(
 			&item.JobName,
 			&lastRun,
@@ -366,6 +375,7 @@ ORDER BY job_name ASC`
 			&lastErrorAt,
 			&lastError,
 			&lastDuration,
+			&lastResult,
 			&item.UpdatedAt,
 		); err != nil {
 			return nil, err
@@ -391,6 +401,10 @@ ORDER BY job_name ASC`
 			v := lastDuration.Int64
 			item.LastDurationMs = &v
 		}
+		if lastResult.Valid {
+			v := lastResult.String
+			item.LastResult = &v
+		}

 		out = append(out, &item)
 	}
74
backend/internal/repository/proxy_latency_cache.go
Normal file
74
backend/internal/repository/proxy_latency_cache.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const proxyLatencyKeyPrefix = "proxy:latency:"
|
||||||
|
|
||||||
|
func proxyLatencyKey(proxyID int64) string {
|
||||||
|
return fmt.Sprintf("%s%d", proxyLatencyKeyPrefix, proxyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
type proxyLatencyCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProxyLatencyCache(rdb *redis.Client) service.ProxyLatencyCache {
|
||||||
|
return &proxyLatencyCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyLatencyCache) GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*service.ProxyLatencyInfo, error) {
|
||||||
|
results := make(map[int64]*service.ProxyLatencyInfo)
|
||||||
|
if len(proxyIDs) == 0 {
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]string, 0, len(proxyIDs))
|
||||||
|
for _, id := range proxyIDs {
|
||||||
|
keys = append(keys, proxyLatencyKey(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
values, err := c.rdb.MGet(ctx, keys...).Result()
|
||||||
|
if err != nil {
|
||||||
|
return results, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, raw := range values {
|
||||||
|
if raw == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var payload []byte
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
payload = []byte(v)
|
||||||
|
case []byte:
|
||||||
|
payload = v
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var info service.ProxyLatencyInfo
|
||||||
|
if err := json.Unmarshal(payload, &info); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
results[proxyIDs[i]] = &info
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyLatencyCache) SetProxyLatency(ctx context.Context, proxyID int64, info *service.ProxyLatencyInfo) error {
|
||||||
|
if info == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(info)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.rdb.Set(ctx, proxyLatencyKey(proxyID), payload, 0).Err()
|
||||||
|
}
|
||||||
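`GetProxyLatencies` is deliberately tolerant when decoding the `MGET` batch: misses come back as `nil`, hits may arrive as `string` or `[]byte` depending on the client path, and malformed payloads are skipped rather than failing the whole batch. A standalone sketch of that decoding loop, using a hypothetical `latencyInfo` struct in place of the project's `service.ProxyLatencyInfo`:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// latencyInfo is a stand-in for service.ProxyLatencyInfo; the field name
// here is illustrative, not the project's actual struct.
type latencyInfo struct {
	LatencyMs int64 `json:"latency_ms"`
}

// decodeMGet mirrors the cache's tolerant decoding: MGET yields []any,
// where each hit may be a string or []byte and misses are nil.
func decodeMGet(ids []int64, values []any) map[int64]latencyInfo {
	out := make(map[int64]latencyInfo)
	for i, raw := range values {
		if raw == nil {
			continue // cache miss
		}
		var payload []byte
		switch v := raw.(type) {
		case string:
			payload = []byte(v)
		case []byte:
			payload = v
		default:
			continue
		}
		var info latencyInfo
		if err := json.Unmarshal(payload, &info); err != nil {
			continue // skip malformed entries instead of failing the batch
		}
		out[ids[i]] = info
	}
	return out
}

func main() {
	got := decodeMGet([]int64{1, 2, 3}, []any{`{"latency_ms":42}`, nil, "not-json"})
	fmt.Println(len(got), got[1].LatencyMs) // 1 42
}
```

Indexing the result map by `ids[i]` (not by key string) keeps the caller working in proxy IDs throughout.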
@@ -7,6 +7,7 @@ import (
 	"io"
 	"log"
 	"net/http"
+	"strings"
 	"time"

 	"github.com/Wei-Shaw/sub2api/internal/config"
@@ -34,7 +35,10 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
 	}
 }

-const defaultIPInfoURL = "https://ipinfo.io/json"
+const (
+	defaultIPInfoURL         = "http://ip-api.com/json/?lang=zh-CN"
+	defaultProxyProbeTimeout = 30 * time.Second
+)

 type proxyProbeService struct {
 	ipInfoURL string
@@ -46,7 +50,7 @@ type proxyProbeService struct {
 func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
 	client, err := httpclient.GetClient(httpclient.Options{
 		ProxyURL: proxyURL,
-		Timeout: 15 * time.Second,
+		Timeout: defaultProxyProbeTimeout,
 		InsecureSkipVerify: s.insecureSkipVerify,
 		ProxyStrict: true,
 		ValidateResolvedIP: s.validateResolvedIP,
@@ -75,10 +79,14 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
 	}

 	var ipInfo struct {
-		IP string `json:"ip"`
+		Status      string `json:"status"`
+		Message     string `json:"message"`
+		Query       string `json:"query"`
 		City string `json:"city"`
 		Region string `json:"region"`
+		RegionName  string `json:"regionName"`
 		Country string `json:"country"`
+		CountryCode string `json:"countryCode"`
 	}

 	body, err := io.ReadAll(resp.Body)
@@ -89,11 +97,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
 	if err := json.Unmarshal(body, &ipInfo); err != nil {
 		return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
 	}
+	if strings.ToLower(ipInfo.Status) != "success" {
+		if ipInfo.Message == "" {
+			ipInfo.Message = "ip-api request failed"
+		}
+		return nil, latencyMs, fmt.Errorf("ip-api request failed: %s", ipInfo.Message)
+	}
+
+	region := ipInfo.RegionName
+	if region == "" {
+		region = ipInfo.Region
+	}
 	return &service.ProxyExitInfo{
-		IP: ipInfo.IP,
+		IP: ipInfo.Query,
 		City: ipInfo.City,
-		Region: ipInfo.Region,
+		Region: region,
 		Country: ipInfo.Country,
+		CountryCode: ipInfo.CountryCode,
 	}, latencyMs, nil
 }
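The prober switched from ipinfo.io to ip-api.com, whose response carries an explicit `status` field (with a `message` on failure), reports the exit IP as `query`, and provides both a short `region` code and a human-readable `regionName`. The new parsing rules in the hunk above can be sketched as a self-contained function (names here are illustrative, not the project's API):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// exitInfo mirrors the fields the patched prober reads from an
// ip-api.com JSON response.
type exitInfo struct {
	Status      string `json:"status"`
	Message     string `json:"message"`
	Query       string `json:"query"`
	Region      string `json:"region"`
	RegionName  string `json:"regionName"`
	CountryCode string `json:"countryCode"`
}

// parseExit applies the same rules as the diff: require status == "success",
// take the IP from "query", and prefer regionName over the short region code.
func parseExit(body []byte) (ip, region string, err error) {
	var info exitInfo
	if err := json.Unmarshal(body, &info); err != nil {
		return "", "", err
	}
	if strings.ToLower(info.Status) != "success" {
		if info.Message == "" {
			info.Message = "ip-api request failed"
		}
		return "", "", fmt.Errorf("ip-api request failed: %s", info.Message)
	}
	region = info.RegionName
	if region == "" {
		region = info.Region // fall back to the short region code
	}
	return info.Query, region, nil
}

func main() {
	ip, region, err := parseExit([]byte(`{"status":"success","query":"1.2.3.4","regionName":"California"}`))
	fmt.Println(ip, region, err) // 1.2.3.4 California <nil>
}
```

Note that ip-api can return HTTP 200 with `"status":"fail"`, which is why the status check happens after JSON decoding rather than on the HTTP status code.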
@@ -21,7 +21,7 @@ type ProxyProbeServiceSuite struct {
 func (s *ProxyProbeServiceSuite) SetupTest() {
 	s.ctx = context.Background()
 	s.prober = &proxyProbeService{
-		ipInfoURL: "http://ipinfo.test/json",
+		ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
 		allowPrivateHosts: true,
 	}
 }
@@ -54,7 +54,7 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
 	s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		seen <- r.RequestURI
 		w.Header().Set("Content-Type", "application/json")
-		_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
+		_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
 	}))

 	info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
@@ -64,11 +64,12 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
 	require.Equal(s.T(), "c", info.City)
 	require.Equal(s.T(), "r", info.Region)
 	require.Equal(s.T(), "cc", info.Country)
+	require.Equal(s.T(), "CC", info.CountryCode)

 	// Verify proxy received the request
 	select {
 	case uri := <-seen:
-		require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
+		require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
 	default:
 		require.Fail(s.T(), "expected proxy to receive request")
 	}
@@ -219,12 +219,54 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
 // CountAccountsByProxyID returns the number of accounts using a specific proxy
 func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
 	var count int64
-	if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil {
+	if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1 AND deleted_at IS NULL", []any{proxyID}, &count); err != nil {
 		return 0, err
 	}
 	return count, nil
 }

+func (r *proxyRepository) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
+	rows, err := r.sql.QueryContext(ctx, `
+		SELECT id, name, platform, type, notes
+		FROM accounts
+		WHERE proxy_id = $1 AND deleted_at IS NULL
+		ORDER BY id DESC
+	`, proxyID)
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = rows.Close() }()
+
+	out := make([]service.ProxyAccountSummary, 0)
+	for rows.Next() {
+		var (
+			id       int64
+			name     string
+			platform string
+			accType  string
+			notes    sql.NullString
+		)
+		if err := rows.Scan(&id, &name, &platform, &accType, &notes); err != nil {
+			return nil, err
+		}
+		var notesPtr *string
+		if notes.Valid {
+			notesPtr = &notes.String
+		}
+		out = append(out, service.ProxyAccountSummary{
+			ID:       id,
+			Name:     name,
+			Platform: platform,
+			Type:     accType,
+			Notes:    notesPtr,
+		})
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
 // GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
 func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
 	rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
backend/internal/repository/session_limit_cache.go (new file, 321 lines)
@@ -0,0 +1,321 @@
package repository

import (
	"context"
	"fmt"
	"strconv"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/redis/go-redis/v9"
)

// Session limit cache constants.
//
// Design notes:
// A Redis sorted set tracks each account's active sessions:
//   - Key: session_limit:account:{accountID}
//   - Member: sessionUUID (extracted from metadata.user_id)
//   - Score: Unix timestamp of the session's last activity
//
// Expired sessions are pruned automatically via ZREMRANGEBYSCORE, so no
// manual TTL bookkeeping is required.
const (
	// Session limit key prefix.
	// Format: session_limit:account:{accountID}
	sessionLimitKeyPrefix = "session_limit:account:"

	// Window cost cache key prefix.
	// Format: window_cost:account:{accountID}
	windowCostKeyPrefix = "window_cost:account:"

	// Window cost cache TTL (30 seconds).
	windowCostCacheTTL = 30 * time.Second
)

var (
	// registerSessionScript registers session activity.
	// It reads the clock via the Redis TIME command, avoiding clock skew
	// across multiple application instances.
	// KEYS[1] = session_limit:account:{accountID}
	// ARGV[1] = maxSessions
	// ARGV[2] = idleTimeout (seconds)
	// ARGV[3] = sessionUUID
	// Returns: 1 = allowed, 0 = rejected
	registerSessionScript = redis.NewScript(`
local key = KEYS[1]
local maxSessions = tonumber(ARGV[1])
local idleTimeout = tonumber(ARGV[2])
local sessionUUID = ARGV[3]

-- Use Redis server time so all instances agree on the clock
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout

-- Prune expired sessions
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)

-- If the session already exists, just refresh its timestamp
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
	-- Session exists: refresh the timestamp
	redis.call('ZADD', key, now, sessionUUID)
	redis.call('EXPIRE', key, idleTimeout + 60)
	return 1
end

-- Check whether the session count limit has been reached
local count = redis.call('ZCARD', key)
if count < maxSessions then
	-- Below the limit: admit the new session
	redis.call('ZADD', key, now, sessionUUID)
	redis.call('EXPIRE', key, idleTimeout + 60)
	return 1
end

-- Limit reached: reject the new session
return 0
`)

	// refreshSessionScript refreshes a session's timestamp.
	// KEYS[1] = session_limit:account:{accountID}
	// ARGV[1] = idleTimeout (seconds)
	// ARGV[2] = sessionUUID
	refreshSessionScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])

-- Only refresh sessions that still exist
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
	redis.call('ZADD', key, now, sessionUUID)
	redis.call('EXPIRE', key, idleTimeout + 60)
end
return 1
`)

	// getActiveSessionCountScript returns the number of active sessions.
	// KEYS[1] = session_limit:account:{accountID}
	// ARGV[1] = idleTimeout (seconds)
	getActiveSessionCountScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout

-- Prune expired sessions
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)

return redis.call('ZCARD', key)
`)

	// isSessionActiveScript checks whether a session is active.
	// KEYS[1] = session_limit:account:{accountID}
	// ARGV[1] = idleTimeout (seconds)
	// ARGV[2] = sessionUUID
	isSessionActiveScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout

-- Fetch the session's timestamp
local score = redis.call('ZSCORE', key, sessionUUID)
if score == false then
	return 0
end

-- Check whether it has idled out
if tonumber(score) <= expireBefore then
	return 0
end

return 1
`)
)

type sessionLimitCache struct {
	rdb                *redis.Client
	defaultIdleTimeout time.Duration // default idle timeout (used by GetActiveSessionCount)
}

// NewSessionLimitCache creates a session limit cache.
// defaultIdleTimeoutMinutes: default idle timeout in minutes, used by
// queries that do not pass an explicit timeout.
func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) service.SessionLimitCache {
	if defaultIdleTimeoutMinutes <= 0 {
		defaultIdleTimeoutMinutes = 5 // default to 5 minutes
	}
	return &sessionLimitCache{
		rdb:                rdb,
		defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
	}
}

// sessionLimitKey builds the Redis key for session limiting.
func sessionLimitKey(accountID int64) string {
	return fmt.Sprintf("%s%d", sessionLimitKeyPrefix, accountID)
}

// windowCostKey builds the Redis key for the window cost cache.
func windowCostKey(accountID int64) string {
	return fmt.Sprintf("%s%d", windowCostKeyPrefix, accountID)
}

// RegisterSession registers session activity.
func (c *sessionLimitCache) RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (bool, error) {
	if sessionUUID == "" || maxSessions <= 0 {
		return true, nil // invalid parameters: allow by default
	}

	key := sessionLimitKey(accountID)
	idleTimeoutSeconds := int(idleTimeout.Seconds())
	if idleTimeoutSeconds <= 0 {
		idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
	}

	result, err := registerSessionScript.Run(ctx, c.rdb, []string{key}, maxSessions, idleTimeoutSeconds, sessionUUID).Int()
	if err != nil {
		return true, err // fail open: let the request through on cache errors
	}
	return result == 1, nil
}

// RefreshSession refreshes a session's timestamp.
func (c *sessionLimitCache) RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error {
	if sessionUUID == "" {
		return nil
	}

	key := sessionLimitKey(accountID)
	idleTimeoutSeconds := int(idleTimeout.Seconds())
	if idleTimeoutSeconds <= 0 {
		idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
	}

	_, err := refreshSessionScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Result()
	return err
}

// GetActiveSessionCount returns the number of active sessions.
func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) {
	key := sessionLimitKey(accountID)
	idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())

	result, err := getActiveSessionCountScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds).Int()
	if err != nil {
		return 0, err
	}
	return result, nil
}

// GetActiveSessionCountBatch returns active session counts for multiple accounts.
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
	if len(accountIDs) == 0 {
		return make(map[int64]int), nil
	}

	results := make(map[int64]int, len(accountIDs))

	// Batch execution via pipeline
	pipe := c.rdb.Pipeline()
	idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())

	cmds := make(map[int64]*redis.Cmd, len(accountIDs))
	for _, accountID := range accountIDs {
		key := sessionLimitKey(accountID)
		cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
	}

	// Execute the pipeline; collect whatever succeeded even on partial failure
	_, _ = pipe.Exec(ctx)

	for accountID, cmd := range cmds {
		if result, err := cmd.Int(); err == nil {
			results[accountID] = result
		}
	}

	return results, nil
}

// IsSessionActive checks whether a session is active.
func (c *sessionLimitCache) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) {
	if sessionUUID == "" {
		return false, nil
	}

	key := sessionLimitKey(accountID)
	idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())

	result, err := isSessionActiveScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Int()
	if err != nil {
		return false, err
	}
	return result == 1, nil
}

// ========== 5h window cost cache ==========

// GetWindowCost returns the cached window cost.
func (c *sessionLimitCache) GetWindowCost(ctx context.Context, accountID int64) (float64, bool, error) {
	key := windowCostKey(accountID)
	val, err := c.rdb.Get(ctx, key).Float64()
	if err == redis.Nil {
		return 0, false, nil // cache miss
	}
	if err != nil {
		return 0, false, err
	}
	return val, true, nil
}

// SetWindowCost caches the window cost.
func (c *sessionLimitCache) SetWindowCost(ctx context.Context, accountID int64, cost float64) error {
	key := windowCostKey(accountID)
	return c.rdb.Set(ctx, key, cost, windowCostCacheTTL).Err()
}

// GetWindowCostBatch fetches cached window costs in bulk.
func (c *sessionLimitCache) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) {
	if len(accountIDs) == 0 {
		return make(map[int64]float64), nil
	}

	// Build the keys for the batch lookup
	keys := make([]string, len(accountIDs))
	for i, accountID := range accountIDs {
		keys[i] = windowCostKey(accountID)
	}

	// Bulk fetch with MGET
	vals, err := c.rdb.MGet(ctx, keys...).Result()
	if err != nil {
		return nil, err
	}

	results := make(map[int64]float64, len(accountIDs))
	for i, val := range vals {
		if val == nil {
			continue // cache miss
		}
		// Try to parse as float64
		switch v := val.(type) {
		case string:
			if cost, err := strconv.ParseFloat(v, 64); err == nil {
				results[accountIDs[i]] = cost
			}
		case float64:
			results[accountIDs[i]] = v
		}
	}

	return results, nil
}
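The admit-or-reject logic of `registerSessionScript` above runs atomically inside Redis, but the algorithm itself is simple enough to sketch against an in-memory map (a stand-in for the sorted set, with timestamps in seconds); this is an illustration, not project code:

```go
package main

import "fmt"

// sessions maps sessionUUID -> last-active Unix timestamp, standing in for
// the Redis sorted set (member -> score).
type sessions map[string]int64

// register mirrors the Lua script: prune idled-out sessions, refresh an
// existing session for free, otherwise admit only while under maxSessions.
func register(s sessions, uuid string, now, idleTimeout int64, maxSessions int) bool {
	expireBefore := now - idleTimeout
	for id, ts := range s { // ZREMRANGEBYSCORE -inf expireBefore
		if ts <= expireBefore {
			delete(s, id)
		}
	}
	if _, ok := s[uuid]; ok { // existing session: refresh its timestamp
		s[uuid] = now
		return true
	}
	if len(s) < maxSessions { // below the limit: admit
		s[uuid] = now
		return true
	}
	return false // limit reached: reject
}

func main() {
	s := sessions{}
	fmt.Println(register(s, "a", 100, 300, 2)) // true
	fmt.Println(register(s, "b", 110, 300, 2)) // true
	fmt.Println(register(s, "c", 120, 300, 2)) // false: limit of 2 reached
	fmt.Println(register(s, "c", 500, 300, 2)) // true: "a" and "b" idled out
}
```

Because pruning happens on every call, an account whose sessions go idle frees its slots automatically; the real implementation additionally sets `EXPIRE idleTimeout + 60` so an abandoned key disappears entirely.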
@@ -22,7 +22,7 @@ import (
 	"github.com/lib/pq"
 )

-const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
+const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"

 type usageLogRepository struct {
 	client *dbent.Client
@@ -105,6 +105,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
 	total_cost,
 	actual_cost,
 	rate_multiplier,
+	account_rate_multiplier,
 	billing_type,
 	stream,
 	duration_ms,
@@ -120,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
 	$8, $9, $10, $11,
 	$12, $13,
 	$14, $15, $16, $17, $18, $19,
-	$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
+	$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30
 )
 ON CONFLICT (request_id, api_key_id) DO NOTHING
 RETURNING id, created_at
@@ -160,6 +161,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
 	log.TotalCost,
 	log.ActualCost,
 	rateMultiplier,
+	log.AccountRateMultiplier,
 	log.BillingType,
 	log.Stream,
 	duration,
@@ -270,13 +272,13 @@ type DashboardStats = usagestats.DashboardStats

 func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
 	stats := &DashboardStats{}
-	now := time.Now().UTC()
-	todayUTC := truncateToDayUTC(now)
+	now := timezone.Now()
+	todayStart := timezone.Today()

-	if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
+	if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
 		return nil, err
 	}
-	if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayUTC, now); err != nil {
+	if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayStart, now); err != nil {
 		return nil, err
 	}

@@ -298,13 +300,13 @@ func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, sta
 	}

 	stats := &DashboardStats{}
-	now := time.Now().UTC()
-	todayUTC := truncateToDayUTC(now)
+	now := timezone.Now()
+	todayStart := timezone.Today()

-	if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
+	if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
 		return nil, err
 	}
-	if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayUTC, now); err != nil {
+	if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayStart, now); err != nil {
 		return nil, err
 	}

@@ -455,7 +457,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
 	FROM usage_dashboard_hourly
 	WHERE bucket_start = $1
 	`
-	hourStart := now.UTC().Truncate(time.Hour)
+	hourStart := now.In(timezone.Location()).Truncate(time.Hour)
 	if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
 		if err != sql.ErrNoRows {
 			return err
@@ -835,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
 	SELECT
 		COUNT(*) as requests,
 		COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
-		COALESCE(SUM(actual_cost), 0) as cost
+		COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
+		COALESCE(SUM(total_cost), 0) as standard_cost,
+		COALESCE(SUM(actual_cost), 0) as user_cost
 	FROM usage_logs
 	WHERE account_id = $1 AND created_at >= $2
|
||||||
`
|
`
|
||||||
@@ -849,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
|||||||
&stats.Requests,
|
&stats.Requests,
|
||||||
&stats.Tokens,
|
&stats.Tokens,
|
||||||
&stats.Cost,
|
&stats.Cost,
|
||||||
|
&stats.StandardCost,
|
||||||
|
&stats.UserCost,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -861,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
|||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(actual_cost), 0) as cost
|
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||||
|
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE account_id = $1 AND created_at >= $2
|
WHERE account_id = $1 AND created_at >= $2
|
||||||
`
|
`
|
||||||
@@ -875,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
|||||||
&stats.Requests,
|
&stats.Requests,
|
||||||
&stats.Tokens,
|
&stats.Tokens,
|
||||||
&stats.Cost,
|
&stats.Cost,
|
||||||
|
&stats.StandardCost,
|
||||||
|
&stats.UserCost,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1400,8 +1410,8 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
|
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
||||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
|
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
|
||||||
dateFormat := "YYYY-MM-DD"
|
dateFormat := "YYYY-MM-DD"
|
||||||
if granularity == "hour" {
|
if granularity == "hour" {
|
||||||
dateFormat = "YYYY-MM-DD HH24:00"
|
dateFormat = "YYYY-MM-DD HH24:00"
|
||||||
@@ -1430,6 +1440,22 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
|||||||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||||||
args = append(args, apiKeyID)
|
args = append(args, apiKeyID)
|
||||||
}
|
}
|
||||||
|
if accountID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||||
|
args = append(args, accountID)
|
||||||
|
}
|
||||||
|
if groupID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||||
|
args = append(args, groupID)
|
||||||
|
}
|
||||||
|
if model != "" {
|
||||||
|
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||||
|
args = append(args, model)
|
||||||
|
}
|
||||||
|
if stream != nil {
|
||||||
|
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||||||
|
args = append(args, *stream)
|
||||||
|
}
|
||||||
query += " GROUP BY date ORDER BY date ASC"
|
query += " GROUP BY date ORDER BY date ASC"
|
||||||
|
|
||||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||||
@@ -1452,9 +1478,15 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
|||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
// GetModelStatsWithFilters returns model statistics with optional filters
|
||||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
|
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
|
||||||
query := `
|
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||||
|
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
||||||
|
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||||
|
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||||
|
}
|
||||||
|
|
||||||
|
query := fmt.Sprintf(`
|
||||||
SELECT
|
SELECT
|
||||||
model,
|
model,
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
@@ -1462,10 +1494,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
|||||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
%s
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE created_at >= $1 AND created_at < $2
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
`
|
`, actualCostExpr)
|
||||||
|
|
||||||
args := []any{startTime, endTime}
|
args := []any{startTime, endTime}
|
||||||
if userID > 0 {
|
if userID > 0 {
|
||||||
@@ -1480,6 +1512,14 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
|||||||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||||
args = append(args, accountID)
|
args = append(args, accountID)
|
||||||
}
|
}
|
||||||
|
if groupID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||||
|
args = append(args, groupID)
|
||||||
|
}
|
||||||
|
if stream != nil {
|
||||||
|
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||||||
|
args = append(args, *stream)
|
||||||
|
}
|
||||||
query += " GROUP BY model ORDER BY total_tokens DESC"
|
query += " GROUP BY model ORDER BY total_tokens DESC"
|
||||||
|
|
||||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||||
@@ -1587,12 +1627,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
|||||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||||
|
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
|
||||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
%s
|
%s
|
||||||
`, buildWhere(conditions))
|
`, buildWhere(conditions))
|
||||||
|
|
||||||
stats := &UsageStats{}
|
stats := &UsageStats{}
|
||||||
|
var totalAccountCost float64
|
||||||
if err := scanSingleRow(
|
if err := scanSingleRow(
|
||||||
ctx,
|
ctx,
|
||||||
r.sql,
|
r.sql,
|
||||||
@@ -1604,10 +1646,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
|||||||
&stats.TotalCacheTokens,
|
&stats.TotalCacheTokens,
|
||||||
&stats.TotalCost,
|
&stats.TotalCost,
|
||||||
&stats.TotalActualCost,
|
&stats.TotalActualCost,
|
||||||
|
&totalAccountCost,
|
||||||
&stats.AverageDurationMs,
|
&stats.AverageDurationMs,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if filters.AccountID > 0 {
|
||||||
|
stats.TotalAccountCost = &totalAccountCost
|
||||||
|
}
|
||||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
@@ -1634,7 +1680,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||||||
GROUP BY date
|
GROUP BY date
|
||||||
@@ -1661,7 +1708,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
var tokens int64
|
var tokens int64
|
||||||
var cost float64
|
var cost float64
|
||||||
var actualCost float64
|
var actualCost float64
|
||||||
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
|
var userCost float64
|
||||||
|
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost, &userCost); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
t, _ := time.Parse("2006-01-02", date)
|
t, _ := time.Parse("2006-01-02", date)
|
||||||
@@ -1672,19 +1720,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
Tokens: tokens,
|
Tokens: tokens,
|
||||||
Cost: cost,
|
Cost: cost,
|
||||||
ActualCost: actualCost,
|
ActualCost: actualCost,
|
||||||
|
UserCost: userCost,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if err = rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var totalActualCost, totalStandardCost float64
|
var totalAccountCost, totalUserCost, totalStandardCost float64
|
||||||
var totalRequests, totalTokens int64
|
var totalRequests, totalTokens int64
|
||||||
var highestCostDay, highestRequestDay *AccountUsageHistory
|
var highestCostDay, highestRequestDay *AccountUsageHistory
|
||||||
|
|
||||||
for i := range history {
|
for i := range history {
|
||||||
h := &history[i]
|
h := &history[i]
|
||||||
totalActualCost += h.ActualCost
|
totalAccountCost += h.ActualCost
|
||||||
|
totalUserCost += h.UserCost
|
||||||
totalStandardCost += h.Cost
|
totalStandardCost += h.Cost
|
||||||
totalRequests += h.Requests
|
totalRequests += h.Requests
|
||||||
totalTokens += h.Tokens
|
totalTokens += h.Tokens
|
||||||
@@ -1711,11 +1761,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
summary := AccountUsageSummary{
|
summary := AccountUsageSummary{
|
||||||
Days: daysCount,
|
Days: daysCount,
|
||||||
ActualDaysUsed: actualDaysUsed,
|
ActualDaysUsed: actualDaysUsed,
|
||||||
TotalCost: totalActualCost,
|
TotalCost: totalAccountCost,
|
||||||
|
TotalUserCost: totalUserCost,
|
||||||
TotalStandardCost: totalStandardCost,
|
TotalStandardCost: totalStandardCost,
|
||||||
TotalRequests: totalRequests,
|
TotalRequests: totalRequests,
|
||||||
TotalTokens: totalTokens,
|
TotalTokens: totalTokens,
|
||||||
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
|
AvgDailyCost: totalAccountCost / float64(actualDaysUsed),
|
||||||
|
AvgDailyUserCost: totalUserCost / float64(actualDaysUsed),
|
||||||
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
|
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
|
||||||
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
|
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
|
||||||
AvgDurationMs: avgDuration,
|
AvgDurationMs: avgDuration,
|
||||||
@@ -1727,11 +1779,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
summary.Today = &struct {
|
summary.Today = &struct {
|
||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Tokens int64 `json:"tokens"`
|
Tokens int64 `json:"tokens"`
|
||||||
}{
|
}{
|
||||||
Date: history[i].Date,
|
Date: history[i].Date,
|
||||||
Cost: history[i].ActualCost,
|
Cost: history[i].ActualCost,
|
||||||
|
UserCost: history[i].UserCost,
|
||||||
Requests: history[i].Requests,
|
Requests: history[i].Requests,
|
||||||
Tokens: history[i].Tokens,
|
Tokens: history[i].Tokens,
|
||||||
}
|
}
|
||||||
@@ -1744,11 +1798,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
Label string `json:"label"`
|
Label string `json:"label"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
}{
|
}{
|
||||||
Date: highestCostDay.Date,
|
Date: highestCostDay.Date,
|
||||||
Label: highestCostDay.Label,
|
Label: highestCostDay.Label,
|
||||||
Cost: highestCostDay.ActualCost,
|
Cost: highestCostDay.ActualCost,
|
||||||
|
UserCost: highestCostDay.UserCost,
|
||||||
Requests: highestCostDay.Requests,
|
Requests: highestCostDay.Requests,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1759,15 +1815,17 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
Label string `json:"label"`
|
Label string `json:"label"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
}{
|
}{
|
||||||
Date: highestRequestDay.Date,
|
Date: highestRequestDay.Date,
|
||||||
Label: highestRequestDay.Label,
|
Label: highestRequestDay.Label,
|
||||||
Requests: highestRequestDay.Requests,
|
Requests: highestRequestDay.Requests,
|
||||||
Cost: highestRequestDay.ActualCost,
|
Cost: highestRequestDay.ActualCost,
|
||||||
|
UserCost: highestRequestDay.UserCost,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
|
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
models = []ModelStat{}
|
models = []ModelStat{}
|
||||||
}
|
}
|
||||||
@@ -2015,6 +2073,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
totalCost float64
|
totalCost float64
|
||||||
actualCost float64
|
actualCost float64
|
||||||
rateMultiplier float64
|
rateMultiplier float64
|
||||||
|
accountRateMultiplier sql.NullFloat64
|
||||||
billingType int16
|
billingType int16
|
||||||
stream bool
|
stream bool
|
||||||
durationMs sql.NullInt64
|
durationMs sql.NullInt64
|
||||||
@@ -2048,6 +2107,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
&totalCost,
|
&totalCost,
|
||||||
&actualCost,
|
&actualCost,
|
||||||
&rateMultiplier,
|
&rateMultiplier,
|
||||||
|
&accountRateMultiplier,
|
||||||
&billingType,
|
&billingType,
|
||||||
&stream,
|
&stream,
|
||||||
&durationMs,
|
&durationMs,
|
||||||
@@ -2080,6 +2140,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
TotalCost: totalCost,
|
TotalCost: totalCost,
|
||||||
ActualCost: actualCost,
|
ActualCost: actualCost,
|
||||||
RateMultiplier: rateMultiplier,
|
RateMultiplier: rateMultiplier,
|
||||||
|
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
|
||||||
BillingType: int8(billingType),
|
BillingType: int8(billingType),
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
ImageCount: imageCount,
|
ImageCount: imageCount,
|
||||||
@@ -2186,6 +2247,14 @@ func nullInt(v *int) sql.NullInt64 {
|
|||||||
return sql.NullInt64{Int64: int64(*v), Valid: true}
|
return sql.NullInt64{Int64: int64(*v), Valid: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func nullFloat64Ptr(v sql.NullFloat64) *float64 {
|
||||||
|
if !v.Valid {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := v.Float64
|
||||||
|
return &out
|
||||||
|
}
|
||||||
|
|
||||||
func nullString(v *string) sql.NullString {
|
func nullString(v *string) sql.NullString {
|
||||||
if v == nil || *v == "" {
|
if v == nil || *v == "" {
|
||||||
return sql.NullString{}
|
return sql.NullString{}
|
||||||
|
|||||||
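The repository changes above consistently distinguish three cost figures per account: the standard cost (`SUM(total_cost)`), the account cost (`SUM(total_cost * COALESCE(account_rate_multiplier, 1))`), and the user cost (`SUM(actual_cost)`). A minimal standalone sketch of that arithmetic, with illustrative type names rather than the repository's actual ones:

```go
package main

import "fmt"

// usageRow mirrors only the columns the queries above aggregate over
// (illustrative; not the repository's actual types).
type usageRow struct {
	TotalCost             float64  // standard cost at list price
	ActualCost            float64  // cost actually charged to the user
	AccountRateMultiplier *float64 // per-account rate; nil behaves like SQL NULL
}

// aggregateCosts reproduces the SQL aggregation in Go:
//   standard = SUM(total_cost)
//   account  = SUM(total_cost * COALESCE(account_rate_multiplier, 1))
//   user     = SUM(actual_cost)
func aggregateCosts(rows []usageRow) (standard, account, user float64) {
	for _, r := range rows {
		m := 1.0 // COALESCE(account_rate_multiplier, 1)
		if r.AccountRateMultiplier != nil {
			m = *r.AccountRateMultiplier
		}
		standard += r.TotalCost
		account += r.TotalCost * m
		user += r.ActualCost
	}
	return
}

func main() {
	m1, m2 := 1.5, 0.0
	rows := []usageRow{
		{TotalCost: 1.0, ActualCost: 2.0, AccountRateMultiplier: &m1},
		{TotalCost: 0.5, ActualCost: 1.0, AccountRateMultiplier: &m2},
	}
	s, a, u := aggregateCosts(rows)
	fmt.Printf("standard=%.1f account=%.1f user=%.1f\n", s, a, u)
}
```

With these inputs the three figures come out as 1.5, 1.5, and 3.0, matching the expectations the updated `TestGetAccountTodayStats` asserts below; note that a multiplier of 0 is honored as 0, while a NULL multiplier falls back to 1.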
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
@@ -36,6 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
|
|||||||
suite.Run(t, new(UsageLogRepoSuite))
|
suite.Run(t, new(UsageLogRepoSuite))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// truncateToDayUTC 截断到 UTC 日期边界(测试辅助函数)
|
||||||
|
func truncateToDayUTC(t time.Time) time.Time {
|
||||||
|
t = t.UTC()
|
||||||
|
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
||||||
log := &service.UsageLog{
|
log := &service.UsageLog{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
@@ -95,6 +102,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
|
|||||||
s.Require().Error(err, "expected error for non-existent ID")
|
s.Require().Error(err, "expected error for non-existent ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
|
||||||
|
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-mult@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-mult", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-mult"})
|
||||||
|
|
||||||
|
m := 0.5
|
||||||
|
log := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.New().String(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 1.0,
|
||||||
|
ActualCost: 2.0,
|
||||||
|
AccountRateMultiplier: &m,
|
||||||
|
CreatedAt: timezone.Today().Add(2 * time.Hour),
|
||||||
|
}
|
||||||
|
_, err := s.repo.Create(s.ctx, log)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, log.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(got.AccountRateMultiplier)
|
||||||
|
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
|
||||||
|
}
|
||||||
|
|
||||||
// --- Delete ---
|
// --- Delete ---
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestDelete() {
|
func (s *UsageLogRepoSuite) TestDelete() {
|
||||||
@@ -403,12 +438,49 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
|||||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
|
||||||
|
|
||||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
createdAt := timezone.Today().Add(1 * time.Hour)
|
||||||
|
|
||||||
|
m1 := 1.5
|
||||||
|
m2 := 0.0
|
||||||
|
_, err := s.repo.Create(s.ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.New().String(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 1.0,
|
||||||
|
ActualCost: 2.0,
|
||||||
|
AccountRateMultiplier: &m1,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
_, err = s.repo.Create(s.ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.New().String(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 5,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 1.0,
|
||||||
|
AccountRateMultiplier: &m2,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
|
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
|
||||||
s.Require().NoError(err, "GetAccountTodayStats")
|
s.Require().NoError(err, "GetAccountTodayStats")
|
||||||
s.Require().Equal(int64(1), stats.Requests)
|
s.Require().Equal(int64(2), stats.Requests)
|
||||||
s.Require().Equal(int64(30), stats.Tokens)
|
s.Require().Equal(int64(40), stats.Tokens)
|
||||||
|
// account cost = SUM(total_cost * account_rate_multiplier)
|
||||||
|
s.Require().InEpsilon(1.5, stats.Cost, 0.0001)
|
||||||
|
// standard cost = SUM(total_cost)
|
||||||
|
s.Require().InEpsilon(1.5, stats.StandardCost, 0.0001)
|
||||||
|
// user cost = SUM(actual_cost)
|
||||||
|
s.Require().InEpsilon(3.0, stats.UserCost, 0.0001)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
|
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
|
||||||
@@ -872,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
|||||||
endTime := base.Add(48 * time.Hour)
|
endTime := base.Add(48 * time.Hour)
|
||||||
|
|
||||||
// Test with user filter
|
// Test with user filter
|
||||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0)
|
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil)
|
||||||
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
|
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
|
||||||
s.Require().Len(trend, 2)
|
s.Require().Len(trend, 2)
|
||||||
|
|
||||||
// Test with apiKey filter
|
// Test with apiKey filter
|
||||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID)
|
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil)
|
||||||
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
|
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
|
||||||
s.Require().Len(trend, 2)
|
s.Require().Len(trend, 2)
|
||||||
|
|
||||||
// Test with both filters
|
// Test with both filters
|
||||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID)
|
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil)
|
||||||
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
|
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
|
||||||
s.Require().Len(trend, 2)
|
s.Require().Len(trend, 2)
|
||||||
}
|
}
|
||||||
@@ -899,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
|||||||
startTime := base.Add(-1 * time.Hour)
|
startTime := base.Add(-1 * time.Hour)
|
||||||
endTime := base.Add(3 * time.Hour)
|
endTime := base.Add(3 * time.Hour)
|
||||||
|
|
||||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0)
|
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil)
|
||||||
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
|
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
|
||||||
s.Require().Len(trend, 2)
|
s.Require().Len(trend, 2)
|
||||||
}
|
}
|
||||||
@@ -945,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
|||||||
endTime := base.Add(2 * time.Hour)
|
endTime := base.Add(2 * time.Hour)
|
||||||
|
|
||||||
// Test with user filter
|
// Test with user filter
|
||||||
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0)
|
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil)
|
||||||
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
|
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
|
||||||
s.Require().Len(stats, 2)
|
s.Require().Len(stats, 2)
|
||||||
|
|
||||||
// Test with apiKey filter
|
// Test with apiKey filter
|
||||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0)
|
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil)
|
||||||
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
|
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
|
||||||
s.Require().Len(stats, 2)
|
s.Require().Len(stats, 2)
|
||||||
|
|
||||||
// Test with account filter
|
// Test with account filter
|
||||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID)
|
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil)
|
||||||
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
|
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
|
||||||
s.Require().Len(stats, 2)
|
s.Require().Len(stats, 2)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
|
|||||||
return NewPricingRemoteClient(cfg.Update.ProxyURL)
|
return NewPricingRemoteClient(cfg.Update.ProxyURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideSessionLimitCache 创建会话限制缓存
|
||||||
|
// 用于 Anthropic OAuth/SetupToken 账号的并发会话数量控制
|
||||||
|
func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.SessionLimitCache {
|
||||||
|
defaultIdleTimeoutMinutes := 5 // 默认 5 分钟空闲超时
|
||||||
|
if cfg != nil && cfg.Gateway.SessionIdleTimeoutMinutes > 0 {
|
||||||
|
defaultIdleTimeoutMinutes = cfg.Gateway.SessionIdleTimeoutMinutes
|
||||||
|
}
|
||||||
|
return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes)
|
||||||
|
}
|
||||||
|
|
||||||
// ProviderSet is the Wire provider set for all repositories
|
// ProviderSet is the Wire provider set for all repositories
|
||||||
var ProviderSet = wire.NewSet(
|
var ProviderSet = wire.NewSet(
|
||||||
NewUserRepository,
|
NewUserRepository,
|
||||||
@@ -61,6 +71,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewTempUnschedCache,
|
NewTempUnschedCache,
|
||||||
NewTimeoutCounterCache,
|
NewTimeoutCounterCache,
|
||||||
ProvideConcurrencyCache,
|
ProvideConcurrencyCache,
|
||||||
|
ProvideSessionLimitCache,
|
||||||
NewDashboardCache,
|
NewDashboardCache,
|
||||||
NewEmailCache,
|
NewEmailCache,
|
||||||
NewIdentityCache,
|
NewIdentityCache,
|
||||||
@@ -69,6 +80,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewGeminiTokenCache,
|
NewGeminiTokenCache,
|
||||||
NewSchedulerCache,
|
NewSchedulerCache,
|
||||||
NewSchedulerOutboxRepository,
|
NewSchedulerOutboxRepository,
|
||||||
|
NewProxyLatencyCache,
|
||||||
|
|
||||||
// HTTP service ports (DI Strategy A: return interface directly)
|
// HTTP service ports (DI Strategy A: return interface directly)
|
||||||
NewTurnstileVerifier,
|
NewTurnstileVerifier,
|
||||||
|
|||||||
@@ -241,6 +241,7 @@ func TestAPIContracts(t *testing.T) {
 	"total_cost": 0.5,
 	"actual_cost": 0.5,
 	"rate_multiplier": 1,
+	"account_rate_multiplier": null,
 	"billing_type": 0,
 	"stream": true,
 	"duration_ms": 100,
@@ -435,12 +436,12 @@ func newContractDeps(t *testing.T) *contractDeps {
 	settingRepo := newStubSettingRepo()
 	settingService := service.NewSettingService(settingRepo, cfg)
 
-	adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil)
+	adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
 	authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
 	apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
 	usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
 	adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
-	adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+	adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
 
 	jwtAuth := func(c *gin.Context) {
 		c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
@@ -779,6 +780,10 @@ func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id
 	return errors.New("not implemented")
 }
 
+func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
+	return errors.New("not implemented")
+}
+
 func (s *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
 	return errors.New("not implemented")
 }
@@ -799,6 +804,10 @@ func (s *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id in
 	return errors.New("not implemented")
 }
 
+func (s *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error {
+	return errors.New("not implemented")
+}
+
 func (s *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
 	return errors.New("not implemented")
 }
@@ -858,6 +867,10 @@ func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64)
 	return 0, errors.New("not implemented")
 }
 
+func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
+	return nil, errors.New("not implemented")
+}
+
 type stubRedeemCodeRepo struct{}
 
 func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
@@ -1229,11 +1242,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
 	return nil, errors.New("not implemented")
 }
 
-func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
+func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
 	return nil, errors.New("not implemented")
 }
 
-func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) {
+func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
 	return nil, errors.New("not implemented")
 }
@@ -1,12 +1,40 @@
 package middleware
 
 import (
+	"crypto/rand"
+	"encoding/base64"
 	"strings"
 
 	"github.com/Wei-Shaw/sub2api/internal/config"
 	"github.com/gin-gonic/gin"
 )
 
+const (
+	// CSPNonceKey is the context key for storing the CSP nonce
+	CSPNonceKey = "csp_nonce"
+	// NonceTemplate is the placeholder in CSP policy for nonce
+	NonceTemplate = "__CSP_NONCE__"
+	// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
+	CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
+)
+
+// GenerateNonce generates a cryptographically secure random nonce
+func GenerateNonce() string {
+	b := make([]byte, 16)
+	_, _ = rand.Read(b)
+	return base64.StdEncoding.EncodeToString(b)
+}
+
+// GetNonceFromContext retrieves the CSP nonce from gin context
+func GetNonceFromContext(c *gin.Context) string {
+	if nonce, exists := c.Get(CSPNonceKey); exists {
+		if s, ok := nonce.(string); ok {
+			return s
+		}
+	}
+	return ""
+}
+
 // SecurityHeaders sets baseline security headers for all responses.
 func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
 	policy := strings.TrimSpace(cfg.Policy)
@@ -14,13 +42,75 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
 		policy = config.DefaultCSPPolicy
 	}
 
+	// Enhance policy with required directives (nonce placeholder and Cloudflare Insights)
+	policy = enhanceCSPPolicy(policy)
+
 	return func(c *gin.Context) {
 		c.Header("X-Content-Type-Options", "nosniff")
 		c.Header("X-Frame-Options", "DENY")
 		c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
 
 		if cfg.Enabled {
-			c.Header("Content-Security-Policy", policy)
+			// Generate nonce for this request
+			nonce := GenerateNonce()
+			c.Set(CSPNonceKey, nonce)
+
+			// Replace nonce placeholder in policy
+			finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
+			c.Header("Content-Security-Policy", finalPolicy)
 		}
 		c.Next()
 	}
 }
+
+// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
+// This allows the application to work correctly even if the config file has an older CSP policy.
+func enhanceCSPPolicy(policy string) string {
+	// Add nonce placeholder to script-src if not present
+	if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
+		policy = addToDirective(policy, "script-src", NonceTemplate)
+	}
+
+	// Add Cloudflare Insights domain to script-src if not present
+	if !strings.Contains(policy, CloudflareInsightsDomain) {
+		policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
+	}
+
+	return policy
+}
+
+// addToDirective adds a value to a specific CSP directive.
+// If the directive doesn't exist, it will be added after default-src.
+func addToDirective(policy, directive, value string) string {
+	// Find the directive in the policy
+	directivePrefix := directive + " "
+	idx := strings.Index(policy, directivePrefix)
+
+	if idx == -1 {
+		// Directive not found, add it after default-src or at the beginning
+		defaultSrcIdx := strings.Index(policy, "default-src ")
+		if defaultSrcIdx != -1 {
+			// Find the end of default-src directive (next semicolon)
+			endIdx := strings.Index(policy[defaultSrcIdx:], ";")
+			if endIdx != -1 {
+				insertPos := defaultSrcIdx + endIdx + 1
+				// Insert new directive after default-src
+				return policy[:insertPos] + " " + directive + " 'self' " + value + ";" + policy[insertPos:]
+			}
+		}
+		// Fallback: prepend the directive
+		return directive + " 'self' " + value + "; " + policy
+	}
+
+	// Find the end of this directive (next semicolon or end of string)
+	endIdx := strings.Index(policy[idx:], ";")
+
+	if endIdx == -1 {
+		// No semicolon found, directive goes to end of string
+		return policy + " " + value
+	}
+
+	// Insert value before the semicolon
+	insertPos := idx + endIdx
+	return policy[:insertPos] + " " + value + policy[insertPos:]
+}
backend/internal/server/middleware/security_headers_test.go (new file, 365 lines)
@@ -0,0 +1,365 @@
+package middleware
+
+import (
+	"encoding/base64"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+
+	"github.com/Wei-Shaw/sub2api/internal/config"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func init() {
+	gin.SetMode(gin.TestMode)
+}
+
+func TestGenerateNonce(t *testing.T) {
+	t.Run("generates_valid_base64_string", func(t *testing.T) {
+		nonce := GenerateNonce()
+
+		// Should be valid base64
+		decoded, err := base64.StdEncoding.DecodeString(nonce)
+		require.NoError(t, err)
+
+		// Should decode to 16 bytes
+		assert.Len(t, decoded, 16)
+	})
+
+	t.Run("generates_unique_nonces", func(t *testing.T) {
+		nonces := make(map[string]bool)
+		for i := 0; i < 100; i++ {
+			nonce := GenerateNonce()
+			assert.False(t, nonces[nonce], "nonce should be unique")
+			nonces[nonce] = true
+		}
+	})
+
+	t.Run("nonce_has_expected_length", func(t *testing.T) {
+		nonce := GenerateNonce()
+		// 16 bytes -> 24 chars in base64 (with padding)
+		assert.Len(t, nonce, 24)
+	})
+}
+
+func TestGetNonceFromContext(t *testing.T) {
+	t.Run("returns_nonce_when_present", func(t *testing.T) {
+		w := httptest.NewRecorder()
+		c, _ := gin.CreateTestContext(w)
+
+		expectedNonce := "test-nonce-123"
+		c.Set(CSPNonceKey, expectedNonce)
+
+		nonce := GetNonceFromContext(c)
+		assert.Equal(t, expectedNonce, nonce)
+	})
+
+	t.Run("returns_empty_string_when_not_present", func(t *testing.T) {
+		w := httptest.NewRecorder()
+		c, _ := gin.CreateTestContext(w)
+
+		nonce := GetNonceFromContext(c)
+		assert.Empty(t, nonce)
+	})
+
+	t.Run("returns_empty_for_wrong_type", func(t *testing.T) {
+		w := httptest.NewRecorder()
+		c, _ := gin.CreateTestContext(w)
+
+		// Set a non-string value
+		c.Set(CSPNonceKey, 12345)
+
+		// Should return empty string for wrong type (safe type assertion)
+		nonce := GetNonceFromContext(c)
+		assert.Empty(t, nonce)
+	})
+}
+
+func TestSecurityHeaders(t *testing.T) {
+	t.Run("sets_basic_security_headers", func(t *testing.T) {
+		cfg := config.CSPConfig{Enabled: false}
+		middleware := SecurityHeaders(cfg)
+
+		w := httptest.NewRecorder()
+		c, _ := gin.CreateTestContext(w)
+		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+		middleware(c)
+
+		assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
+		assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
+		assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
+	})
+
+	t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
+		cfg := config.CSPConfig{Enabled: false}
+		middleware := SecurityHeaders(cfg)
+
+		w := httptest.NewRecorder()
+		c, _ := gin.CreateTestContext(w)
+		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+		middleware(c)
+
+		assert.Empty(t, w.Header().Get("Content-Security-Policy"))
+	})
+
+	t.Run("csp_enabled_sets_csp_header", func(t *testing.T) {
+		cfg := config.CSPConfig{
+			Enabled: true,
+			Policy:  "default-src 'self'",
+		}
+		middleware := SecurityHeaders(cfg)
+
+		w := httptest.NewRecorder()
+		c, _ := gin.CreateTestContext(w)
+		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+		middleware(c)
+
+		csp := w.Header().Get("Content-Security-Policy")
+		assert.NotEmpty(t, csp)
+		// Policy is auto-enhanced with nonce and Cloudflare Insights domain
+		assert.Contains(t, csp, "default-src 'self'")
+		assert.Contains(t, csp, "'nonce-")
+		assert.Contains(t, csp, CloudflareInsightsDomain)
+	})
+
+	t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
+		cfg := config.CSPConfig{
+			Enabled: true,
+			Policy:  "script-src 'self' __CSP_NONCE__",
+		}
+		middleware := SecurityHeaders(cfg)
+
+		w := httptest.NewRecorder()
+		c, _ := gin.CreateTestContext(w)
+		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+		middleware(c)
+
+		csp := w.Header().Get("Content-Security-Policy")
+		assert.NotEmpty(t, csp)
+		assert.NotContains(t, csp, "__CSP_NONCE__", "placeholder should be replaced")
+		assert.Contains(t, csp, "'nonce-", "should contain nonce directive")
+
+		// Verify nonce is stored in context
+		nonce := GetNonceFromContext(c)
+		assert.NotEmpty(t, nonce)
+		assert.Contains(t, csp, "'nonce-"+nonce+"'")
+	})
+
+	t.Run("uses_default_policy_when_empty", func(t *testing.T) {
+		cfg := config.CSPConfig{
+			Enabled: true,
+			Policy:  "",
+		}
+		middleware := SecurityHeaders(cfg)
+
+		w := httptest.NewRecorder()
+		c, _ := gin.CreateTestContext(w)
+		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+		middleware(c)
+
+		csp := w.Header().Get("Content-Security-Policy")
+		assert.NotEmpty(t, csp)
+		// Default policy should contain these elements
+		assert.Contains(t, csp, "default-src 'self'")
+	})
+
+	t.Run("uses_default_policy_when_whitespace_only", func(t *testing.T) {
+		cfg := config.CSPConfig{
+			Enabled: true,
+			Policy:  "   \t\n  ",
+		}
+		middleware := SecurityHeaders(cfg)
+
+		w := httptest.NewRecorder()
+		c, _ := gin.CreateTestContext(w)
+		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+		middleware(c)
+
+		csp := w.Header().Get("Content-Security-Policy")
+		assert.NotEmpty(t, csp)
+		assert.Contains(t, csp, "default-src 'self'")
+	})
+
+	t.Run("multiple_nonce_placeholders_replaced", func(t *testing.T) {
+		cfg := config.CSPConfig{
+			Enabled: true,
+			Policy:  "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
+		}
+		middleware := SecurityHeaders(cfg)
+
+		w := httptest.NewRecorder()
+		c, _ := gin.CreateTestContext(w)
+		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+		middleware(c)
+
+		csp := w.Header().Get("Content-Security-Policy")
+		nonce := GetNonceFromContext(c)
+
+		// Count occurrences of the nonce
+		count := strings.Count(csp, "'nonce-"+nonce+"'")
+		assert.Equal(t, 2, count, "both placeholders should be replaced with same nonce")
+	})
+
+	t.Run("calls_next_handler", func(t *testing.T) {
+		cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
+		middleware := SecurityHeaders(cfg)
+
+		nextCalled := false
+		router := gin.New()
+		router.Use(middleware)
+		router.GET("/test", func(c *gin.Context) {
+			nextCalled = true
+			c.Status(http.StatusOK)
+		})
+
+		w := httptest.NewRecorder()
+		req := httptest.NewRequest(http.MethodGet, "/test", nil)
+		router.ServeHTTP(w, req)
+
+		assert.True(t, nextCalled, "next handler should be called")
+		assert.Equal(t, http.StatusOK, w.Code)
+	})
+
+	t.Run("nonce_unique_per_request", func(t *testing.T) {
+		cfg := config.CSPConfig{
+			Enabled: true,
+			Policy:  "script-src __CSP_NONCE__",
+		}
+		middleware := SecurityHeaders(cfg)
+
+		nonces := make(map[string]bool)
+		for i := 0; i < 10; i++ {
+			w := httptest.NewRecorder()
+			c, _ := gin.CreateTestContext(w)
+			c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+			middleware(c)
+
+			nonce := GetNonceFromContext(c)
+			assert.False(t, nonces[nonce], "nonce should be unique per request")
+			nonces[nonce] = true
+		}
+	})
+}
+
+func TestCSPNonceKey(t *testing.T) {
+	t.Run("constant_value", func(t *testing.T) {
+		assert.Equal(t, "csp_nonce", CSPNonceKey)
+	})
+}
+
+func TestNonceTemplate(t *testing.T) {
+	t.Run("constant_value", func(t *testing.T) {
+		assert.Equal(t, "__CSP_NONCE__", NonceTemplate)
+	})
+}
+
+func TestEnhanceCSPPolicy(t *testing.T) {
+	t.Run("adds_nonce_placeholder_if_missing", func(t *testing.T) {
+		policy := "default-src 'self'; script-src 'self'"
+		enhanced := enhanceCSPPolicy(policy)
+
+		assert.Contains(t, enhanced, NonceTemplate)
+		assert.Contains(t, enhanced, CloudflareInsightsDomain)
+	})
+
+	t.Run("does_not_duplicate_nonce_placeholder", func(t *testing.T) {
+		policy := "default-src 'self'; script-src 'self' __CSP_NONCE__"
+		enhanced := enhanceCSPPolicy(policy)
+
+		// Should not duplicate
+		count := strings.Count(enhanced, NonceTemplate)
+		assert.Equal(t, 1, count)
+	})
+
+	t.Run("does_not_duplicate_cloudflare_domain", func(t *testing.T) {
+		policy := "default-src 'self'; script-src 'self' https://static.cloudflareinsights.com"
+		enhanced := enhanceCSPPolicy(policy)
+
+		count := strings.Count(enhanced, CloudflareInsightsDomain)
+		assert.Equal(t, 1, count)
+	})
+
+	t.Run("handles_policy_without_script_src", func(t *testing.T) {
+		policy := "default-src 'self'"
+		enhanced := enhanceCSPPolicy(policy)
+
+		assert.Contains(t, enhanced, "script-src")
+		assert.Contains(t, enhanced, NonceTemplate)
+		assert.Contains(t, enhanced, CloudflareInsightsDomain)
+	})
+
+	t.Run("preserves_existing_nonce", func(t *testing.T) {
+		policy := "script-src 'self' 'nonce-existing'"
+		enhanced := enhanceCSPPolicy(policy)
+
+		// Should not add placeholder if nonce already exists
+		assert.NotContains(t, enhanced, NonceTemplate)
+		assert.Contains(t, enhanced, "'nonce-existing'")
+	})
+}
+
+func TestAddToDirective(t *testing.T) {
+	t.Run("adds_to_existing_directive", func(t *testing.T) {
+		policy := "script-src 'self'; style-src 'self'"
+		result := addToDirective(policy, "script-src", "https://example.com")
+
+		assert.Contains(t, result, "script-src 'self' https://example.com")
+	})
+
+	t.Run("creates_directive_if_not_exists", func(t *testing.T) {
+		policy := "default-src 'self'"
+		result := addToDirective(policy, "script-src", "https://example.com")
+
+		assert.Contains(t, result, "script-src")
+		assert.Contains(t, result, "https://example.com")
+	})
+
+	t.Run("handles_directive_at_end_without_semicolon", func(t *testing.T) {
+		policy := "default-src 'self'; script-src 'self'"
+		result := addToDirective(policy, "script-src", "https://example.com")
+
+		assert.Contains(t, result, "https://example.com")
+	})
+
+	t.Run("handles_empty_policy", func(t *testing.T) {
+		policy := ""
+		result := addToDirective(policy, "script-src", "https://example.com")
+
+		assert.Contains(t, result, "script-src")
+		assert.Contains(t, result, "https://example.com")
+	})
+}
+
+// Benchmark tests
+func BenchmarkGenerateNonce(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		GenerateNonce()
+	}
+}
+
+func BenchmarkSecurityHeadersMiddleware(b *testing.B) {
+	cfg := config.CSPConfig{
+		Enabled: true,
+		Policy:  "script-src 'self' __CSP_NONCE__",
+	}
+	middleware := SecurityHeaders(cfg)
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		w := httptest.NewRecorder()
+		c, _ := gin.CreateTestContext(w)
+		c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+		middleware(c)
+	}
+}
@@ -81,6 +81,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
 	ops.PUT("/alert-rules/:id", h.Admin.Ops.UpdateAlertRule)
 	ops.DELETE("/alert-rules/:id", h.Admin.Ops.DeleteAlertRule)
 	ops.GET("/alert-events", h.Admin.Ops.ListAlertEvents)
+	ops.GET("/alert-events/:id", h.Admin.Ops.GetAlertEvent)
+	ops.PUT("/alert-events/:id/status", h.Admin.Ops.UpdateAlertEventStatus)
+	ops.POST("/alert-silences", h.Admin.Ops.CreateAlertSilence)
 
 	// Email notification config (DB-backed)
 	ops.GET("/email-notification/config", h.Admin.Ops.GetEmailNotificationConfig)
@@ -110,10 +113,26 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
 		ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
 	}
 
-	// Error logs (MVP-1)
+	// Error logs (legacy)
 	ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
 	ops.GET("/errors/:id", h.Admin.Ops.GetErrorLogByID)
+	ops.GET("/errors/:id/retries", h.Admin.Ops.ListRetryAttempts)
 	ops.POST("/errors/:id/retry", h.Admin.Ops.RetryErrorRequest)
+	ops.PUT("/errors/:id/resolve", h.Admin.Ops.UpdateErrorResolution)
+
+	// Request errors (client-visible failures)
+	ops.GET("/request-errors", h.Admin.Ops.ListRequestErrors)
+	ops.GET("/request-errors/:id", h.Admin.Ops.GetRequestError)
+	ops.GET("/request-errors/:id/upstream-errors", h.Admin.Ops.ListRequestErrorUpstreamErrors)
+	ops.POST("/request-errors/:id/retry-client", h.Admin.Ops.RetryRequestErrorClient)
+	ops.POST("/request-errors/:id/upstream-errors/:idx/retry", h.Admin.Ops.RetryRequestErrorUpstreamEvent)
+	ops.PUT("/request-errors/:id/resolve", h.Admin.Ops.ResolveRequestError)
+
+	// Upstream errors (independent upstream failures)
+	ops.GET("/upstream-errors", h.Admin.Ops.ListUpstreamErrors)
+	ops.GET("/upstream-errors/:id", h.Admin.Ops.GetUpstreamError)
+	ops.POST("/upstream-errors/:id/retry", h.Admin.Ops.RetryUpstreamError)
+	ops.PUT("/upstream-errors/:id/resolve", h.Admin.Ops.ResolveUpstreamError)
 
 	// Request drilldown (success + error)
 	ops.GET("/requests", h.Admin.Ops.ListRequestDetails)
@@ -250,6 +269,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
 		proxies.POST("/:id/test", h.Admin.Proxy.Test)
 		proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
 		proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
+		proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)
 		proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
 	}
 }
@@ -19,6 +19,9 @@ type Account struct {
 	ProxyID     *int64
 	Concurrency int
 	Priority    int
+	// RateMultiplier is the account billing rate multiplier (>= 0; 0 is allowed and means the account is billed at zero).
+	// A pointer keeps compatibility with older scheduler cache (Redis) entries missing the field: nil is treated as 1.0.
+	RateMultiplier *float64
 	Status       string
 	ErrorMessage string
 	LastUsedAt   *time.Time
@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool {
 	return a.Status == StatusActive
 }
 
+// BillingRateMultiplier returns the account billing rate multiplier.
+// - nil means unconfigured or missing from an old cache entry; treated as 1.0
+// - 0 is allowed and means the account is billed at zero
+// - negative values are invalid data and fall back to 1.0 for safety
+func (a *Account) BillingRateMultiplier() float64 {
+	if a == nil || a.RateMultiplier == nil {
+		return 1.0
+	}
+	if *a.RateMultiplier < 0 {
+		return 1.0
+	}
+	return *a.RateMultiplier
+}
+
 func (a *Account) IsSchedulable() bool {
 	if !a.IsActive() || !a.Schedulable {
 		return false
@@ -556,3 +573,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
 	}
 	return false
 }
+
+// WindowCostSchedulability is the scheduling state derived from the 5h window cost.
+type WindowCostSchedulability int
+
+const (
+	// WindowCostSchedulable: the account can be scheduled normally
+	WindowCostSchedulable WindowCostSchedulability = iota
+	// WindowCostStickyOnly: only sticky sessions are allowed
+	WindowCostStickyOnly
+	// WindowCostNotSchedulable: the account cannot be scheduled at all
+	WindowCostNotSchedulable
+)
+
+// IsAnthropicOAuthOrSetupToken reports whether this is an Anthropic OAuth or SetupToken account.
+// Only these two account types support 5h window quota control and session count control.
+func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
+	return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
+}
+
+// GetWindowCostLimit returns the 5h window cost threshold in USD.
+// A return value of 0 means the limit is disabled.
+func (a *Account) GetWindowCostLimit() float64 {
+	if a.Extra == nil {
+		return 0
+	}
+	if v, ok := a.Extra["window_cost_limit"]; ok {
+		return parseExtraFloat64(v)
+	}
+	return 0
+}
+
+// GetWindowCostStickyReserve returns the sticky-session reserve amount in USD.
+// Defaults to 10.
+func (a *Account) GetWindowCostStickyReserve() float64 {
+	if a.Extra == nil {
+		return 10.0
+	}
+	if v, ok := a.Extra["window_cost_sticky_reserve"]; ok {
+		val := parseExtraFloat64(v)
+		if val > 0 {
+			return val
+		}
+	}
+	return 10.0
+}
+
+// GetMaxSessions returns the maximum number of concurrent sessions.
+// A return value of 0 means the limit is disabled.
+func (a *Account) GetMaxSessions() int {
+	if a.Extra == nil {
+		return 0
+	}
+	if v, ok := a.Extra["max_sessions"]; ok {
+		return parseExtraInt(v)
+	}
+	return 0
+}
+
+// GetSessionIdleTimeoutMinutes returns the session idle timeout in minutes.
+// Defaults to 5 minutes.
+func (a *Account) GetSessionIdleTimeoutMinutes() int {
+	if a.Extra == nil {
+		return 5
+	}
+	if v, ok := a.Extra["session_idle_timeout_minutes"]; ok {
+		val := parseExtraInt(v)
+		if val > 0 {
+			return val
+		}
+	}
+	return 5
+}
+
+// CheckWindowCostSchedulability evaluates the scheduling state for the current window cost:
+// - cost < limit: WindowCostSchedulable (normal scheduling)
+// - limit <= cost < limit+reserve: WindowCostStickyOnly (sticky sessions only)
+// - cost >= limit+reserve: WindowCostNotSchedulable (not schedulable)
+func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) WindowCostSchedulability {
+	limit := a.GetWindowCostLimit()
+	if limit <= 0 {
+		return WindowCostSchedulable
+	}
+
+	if currentWindowCost < limit {
+		return WindowCostSchedulable
+	}
+
+	stickyReserve := a.GetWindowCostStickyReserve()
+	if currentWindowCost < limit+stickyReserve {
+		return WindowCostStickyOnly
+	}
+
+	return WindowCostNotSchedulable
+}
+
+// parseExtraFloat64 parses a float64 value from the extra field.
+func parseExtraFloat64(value any) float64 {
+	switch v := value.(type) {
+	case float64:
+		return v
+	case float32:
+		return float64(v)
+	case int:
+		return float64(v)
+	case int64:
+		return float64(v)
+	case json.Number:
+		if f, err := v.Float64(); err == nil {
+			return f
+		}
+	case string:
+		if f, err := strconv.ParseFloat(strings.TrimSpace(v), 64); err == nil {
+			return f
+		}
+	}
+	return 0
+}
+
+// parseExtraInt parses an int value from the extra field.
+func parseExtraInt(value any) int {
+	switch v := value.(type) {
+	case int:
+		return v
+	case int64:
+		return int(v)
+	case float64:
+		return int(v)
+	case json.Number:
+		if i, err := v.Int64(); err == nil {
+			return int(i)
+		}
+	case string:
+		if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
+			return i
+		}
+	}
+	return 0
+}
@@ -0,0 +1,27 @@
+package service
+
+import (
+	"encoding/json"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
+	var a Account
+	require.NoError(t, json.Unmarshal([]byte(`{"id":1,"name":"acc","status":"active"}`), &a))
+	require.Nil(t, a.RateMultiplier)
+	require.Equal(t, 1.0, a.BillingRateMultiplier())
+}
+
+func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
+	v := 0.0
+	a := Account{RateMultiplier: &v}
+	require.Equal(t, 0.0, a.BillingRateMultiplier())
+}
+
+func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
+	v := -1.0
+	a := Account{RateMultiplier: &v}
+	require.Equal(t, 1.0, a.BillingRateMultiplier())
+}
@@ -50,11 +50,13 @@ type AccountRepository interface {
 	SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
 	SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
+	SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
 	SetOverloaded(ctx context.Context, id int64, until time.Time) error
 	SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
 	ClearTempUnschedulable(ctx context.Context, id int64) error
 	ClearRateLimit(ctx context.Context, id int64) error
 	ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
+	ClearModelRateLimits(ctx context.Context, id int64) error
 	UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
 	UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
 	BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
@@ -67,6 +69,7 @@ type AccountBulkUpdate struct {
 	ProxyID     *int64
 	Concurrency *int
 	Priority    *int
+	RateMultiplier *float64
 	Status      *string
 	Schedulable *bool
 	Credentials map[string]any
@@ -143,6 +143,10 @@ func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id
 	panic("unexpected SetAntigravityQuotaScopeLimit call")
 }
+
+func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
+	panic("unexpected SetModelRateLimit call")
+}
 
 func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
 	panic("unexpected SetOverloaded call")
 }
@@ -163,6 +167,10 @@ func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id in
 	panic("unexpected ClearAntigravityQuotaScopes call")
 }
+
+func (s *accountRepoStub) ClearModelRateLimits(ctx context.Context, id int64) error {
+	panic("unexpected ClearModelRateLimits call")
+}
 
 func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
 	panic("unexpected UpdateSessionWindow call")
 }
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
 
 	// Admin dashboard stats
 	GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
-	GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
-	GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
+	GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
+	GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
 	GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
 	GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
 	GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache {
 }
 
 // WindowStats holds usage statistics for a billing window.
+//
+// cost: account-level cost (total_cost * account_rate_multiplier)
+// standard_cost: standard cost (total_cost, without any multiplier)
+// user_cost: user/API-key-level cost (actual_cost, affected by the group multiplier)
 type WindowStats struct {
 	Requests int64   `json:"requests"`
 	Tokens   int64   `json:"tokens"`
 	Cost     float64 `json:"cost"`
+	StandardCost float64 `json:"standard_cost"`
+	UserCost     float64 `json:"user_cost"`
 }
 
 // UsageProgress tracks usage progress.
@@ -266,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
 	}
 
 	dayStart := geminiDailyWindowStart(now)
-	stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID)
+	stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
 	if err != nil {
 		return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
 	}
@@ -288,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
 	// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
 	minuteStart := now.Truncate(time.Minute)
 	minuteResetAt := minuteStart.Add(time.Minute)
-	minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID)
+	minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
 	if err != nil {
 		return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
 	}
@@ -380,6 +386,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
 		Requests: stats.Requests,
 		Tokens:   stats.Tokens,
 		Cost:     stats.Cost,
+		StandardCost: stats.StandardCost,
+		UserCost:     stats.UserCost,
 	}
 
 	// Cache the window stats for 1 minute.
@@ -406,6 +414,8 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
 		Requests: stats.Requests,
 		Tokens:   stats.Tokens,
 		Cost:     stats.Cost,
+		StandardCost: stats.StandardCost,
+		UserCost:     stats.UserCost,
 	}, nil
 }
@@ -565,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
 		},
 	}
 }
+
+// GetAccountWindowStats returns the account's usage statistics within the given time window.
+// Used by the account list page to display the current window cost.
+func (s *AccountUsageService) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
+	return s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime)
+}
@@ -54,7 +54,8 @@ type AdminService interface {
 	CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
 	UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
 	DeleteProxy(ctx context.Context, id int64) error
-	GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
+	BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error)
+	GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
 	CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
 	TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
@@ -105,6 +106,9 @@ type CreateGroupInput struct {
 	ImagePrice4K    *float64
 	ClaudeCodeOnly  bool   // allow only Claude Code clients
 	FallbackGroupID *int64 // fallback group ID
+	// Model routing config (anthropic platform only)
+	ModelRouting        map[string][]int64
+	ModelRoutingEnabled bool // whether model routing is enabled
 }
 
 type UpdateGroupInput struct {
@@ -124,6 +128,9 @@ type UpdateGroupInput struct {
 	ImagePrice4K    *float64
 	ClaudeCodeOnly  *bool  // allow only Claude Code clients
 	FallbackGroupID *int64 // fallback group ID
+	// Model routing config (anthropic platform only)
+	ModelRouting        map[string][]int64
+	ModelRoutingEnabled *bool // whether model routing is enabled
 }
 
 type CreateAccountInput struct {
@@ -136,6 +143,7 @@ type CreateAccountInput struct {
 	ProxyID     *int64
 	Concurrency int
 	Priority    int
+	RateMultiplier *float64 // account billing rate multiplier (>= 0; 0 is allowed)
 	GroupIDs    []int64
 	ExpiresAt   *int64
 	AutoPauseOnExpired *bool
@@ -153,6 +161,7 @@ type UpdateAccountInput struct {
 	ProxyID     *int64
 	Concurrency *int // pointer distinguishes "not provided" from "set to 0"
 	Priority    *int // pointer distinguishes "not provided" from "set to 0"
+	RateMultiplier *float64 // account billing rate multiplier (>= 0; 0 is allowed)
 	Status      string
 	GroupIDs    *[]int64
 	ExpiresAt   *int64
@@ -167,6 +176,7 @@ type BulkUpdateAccountsInput struct {
 	ProxyID     *int64
 	Concurrency *int
 	Priority    *int
+	RateMultiplier *float64 // account billing rate multiplier (>= 0; 0 is allowed)
 	Status      string
 	Schedulable *bool
 	GroupIDs    *[]int64
@@ -220,6 +230,16 @@ type GenerateRedeemCodesInput struct {
 	ValidityDays int // subscription type only: number of valid days
 }
 
+type ProxyBatchDeleteResult struct {
+	DeletedIDs []int64                   `json:"deleted_ids"`
+	Skipped    []ProxyBatchDeleteSkipped `json:"skipped"`
+}
+
+type ProxyBatchDeleteSkipped struct {
+	ID     int64  `json:"id"`
+	Reason string `json:"reason"`
+}
+
 // ProxyTestResult represents the result of testing a proxy
 type ProxyTestResult struct {
 	Success bool `json:"success"`
@@ -229,14 +249,16 @@ type ProxyTestResult struct {
 	City    string `json:"city,omitempty"`
 	Region  string `json:"region,omitempty"`
 	Country string `json:"country,omitempty"`
+	CountryCode string `json:"country_code,omitempty"`
 }
 
-// ProxyExitInfo represents proxy exit information from ipinfo.io
+// ProxyExitInfo represents proxy exit information from ip-api.com
 type ProxyExitInfo struct {
 	IP      string
 	City    string
 	Region  string
 	Country string
+	CountryCode string
 }
 
 // ProxyExitInfoProber tests proxy connectivity and retrieves exit information
@@ -254,6 +276,7 @@ type adminServiceImpl struct {
 	redeemCodeRepo       RedeemCodeRepository
 	billingCacheService  *BillingCacheService
 	proxyProber          ProxyExitInfoProber
+	proxyLatencyCache    ProxyLatencyCache
 	authCacheInvalidator APIKeyAuthCacheInvalidator
 }
@@ -267,6 +290,7 @@ func NewAdminService(
 	redeemCodeRepo RedeemCodeRepository,
 	billingCacheService *BillingCacheService,
 	proxyProber ProxyExitInfoProber,
+	proxyLatencyCache ProxyLatencyCache,
 	authCacheInvalidator APIKeyAuthCacheInvalidator,
 ) AdminService {
 	return &adminServiceImpl{
@@ -278,6 +302,7 @@ func NewAdminService(
 		redeemCodeRepo:       redeemCodeRepo,
 		billingCacheService:  billingCacheService,
 		proxyProber:          proxyProber,
+		proxyLatencyCache:    proxyLatencyCache,
 		authCacheInvalidator: authCacheInvalidator,
 	}
 }
@@ -562,6 +587,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
 		ImagePrice4K:    imagePrice4K,
 		ClaudeCodeOnly:  input.ClaudeCodeOnly,
 		FallbackGroupID: input.FallbackGroupID,
+		ModelRouting:    input.ModelRouting,
 	}
 	if err := s.groupRepo.Create(ctx, group); err != nil {
 		return nil, err
@@ -690,6 +716,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
 		}
 	}
+
+	// Model routing config
+	if input.ModelRouting != nil {
+		group.ModelRouting = input.ModelRouting
+	}
+	if input.ModelRoutingEnabled != nil {
+		group.ModelRoutingEnabled = *input.ModelRoutingEnabled
+	}
+
 	if err := s.groupRepo.Update(ctx, group); err != nil {
 		return nil, err
 	}
@@ -817,6 +851,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
 	} else {
 		account.AutoPauseOnExpired = true
 	}
+	if input.RateMultiplier != nil {
+		if *input.RateMultiplier < 0 {
+			return nil, errors.New("rate_multiplier must be >= 0")
+		}
+		account.RateMultiplier = input.RateMultiplier
+	}
 	if err := s.accountRepo.Create(ctx, account); err != nil {
 		return nil, err
 	}
@@ -869,6 +909,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
 	if input.Priority != nil {
 		account.Priority = *input.Priority
 	}
+	if input.RateMultiplier != nil {
+		if *input.RateMultiplier < 0 {
+			return nil, errors.New("rate_multiplier must be >= 0")
+		}
+		account.RateMultiplier = input.RateMultiplier
+	}
 	if input.Status != "" {
 		account.Status = input.Status
 	}
@@ -942,6 +988,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
 		}
 	}
+
+	if input.RateMultiplier != nil {
+		if *input.RateMultiplier < 0 {
+			return nil, errors.New("rate_multiplier must be >= 0")
+		}
+	}
+
 	// Prepare bulk updates for columns and JSONB fields.
 	repoUpdates := AccountBulkUpdate{
 		Credentials: input.Credentials,
@@ -959,6 +1011,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
 	if input.Priority != nil {
 		repoUpdates.Priority = input.Priority
 	}
+	if input.RateMultiplier != nil {
+		repoUpdates.RateMultiplier = input.RateMultiplier
+	}
 	if input.Status != "" {
 		repoUpdates.Status = &input.Status
 	}
@@ -1069,6 +1124,7 @@ func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page
 	if err != nil {
 		return nil, 0, err
 	}
+	s.attachProxyLatency(ctx, proxies)
 	return proxies, result.Total, nil
 }
@@ -1077,7 +1133,12 @@ func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
 }
 
 func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
-	return s.proxyRepo.ListActiveWithAccountCount(ctx)
+	proxies, err := s.proxyRepo.ListActiveWithAccountCount(ctx)
+	if err != nil {
+		return nil, err
+	}
+	s.attachProxyLatency(ctx, proxies)
+	return proxies, nil
 }
 
 func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
@@ -1097,6 +1158,8 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
 	if err := s.proxyRepo.Create(ctx, proxy); err != nil {
 		return nil, err
 	}
+	// Probe latency asynchronously so creation isn't blocked by network timeout.
+	go s.probeProxyLatency(context.Background(), proxy)
 	return proxy, nil
 }
@@ -1135,12 +1198,53 @@ func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *Upd
 }
 
 func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
+	count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
+	if err != nil {
+		return err
+	}
+	if count > 0 {
+		return ErrProxyInUse
+	}
 	return s.proxyRepo.Delete(ctx, id)
 }
 
-func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
-	// Return mock data for now - would need a dedicated repository method
-	return []Account{}, 0, nil
+func (s *adminServiceImpl) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) {
+	result := &ProxyBatchDeleteResult{}
+	if len(ids) == 0 {
+		return result, nil
+	}
+
+	for _, id := range ids {
+		count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
+		if err != nil {
+			result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
+				ID:     id,
+				Reason: err.Error(),
+			})
+			continue
+		}
+		if count > 0 {
+			result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
+				ID:     id,
+				Reason: ErrProxyInUse.Error(),
+			})
+			continue
+		}
+		if err := s.proxyRepo.Delete(ctx, id); err != nil {
+			result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
+				ID:     id,
+				Reason: err.Error(),
+			})
+			continue
+		}
+		result.DeletedIDs = append(result.DeletedIDs, id)
+	}
+
+	return result, nil
+}
+
+func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
+	return s.proxyRepo.ListAccountSummariesByProxyID(ctx, proxyID)
 }
 
 func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
@@ -1240,12 +1344,29 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
 	proxyURL := proxy.URL()
 	exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
 	if err != nil {
+		s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
+			Success:   false,
+			Message:   err.Error(),
+			UpdatedAt: time.Now(),
+		})
 		return &ProxyTestResult{
 			Success: false,
 			Message: err.Error(),
 		}, nil
 	}
 
+	latency := latencyMs
+	s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
+		Success:     true,
+		LatencyMs:   &latency,
+		Message:     "Proxy is accessible",
+		IPAddress:   exitInfo.IP,
+		Country:     exitInfo.Country,
+		CountryCode: exitInfo.CountryCode,
+		Region:      exitInfo.Region,
+		City:        exitInfo.City,
+		UpdatedAt:   time.Now(),
+	})
 	return &ProxyTestResult{
 		Success: true,
 		Message: "Proxy is accessible",
@@ -1254,9 +1375,38 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
 		City:        exitInfo.City,
 		Region:      exitInfo.Region,
 		Country:     exitInfo.Country,
+		CountryCode: exitInfo.CountryCode,
 	}, nil
 }
 
+func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
+	if s.proxyProber == nil || proxy == nil {
+		return
+	}
+	exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL())
+	if err != nil {
+		s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
+			Success:   false,
+			Message:   err.Error(),
+			UpdatedAt: time.Now(),
+		})
+		return
+	}
+
+	latency := latencyMs
+	s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
+		Success:     true,
+		LatencyMs:   &latency,
+		Message:     "Proxy is accessible",
+		IPAddress:   exitInfo.IP,
+		Country:     exitInfo.Country,
+		CountryCode: exitInfo.CountryCode,
+		Region:      exitInfo.Region,
+		City:        exitInfo.City,
+		UpdatedAt:   time.Now(),
+	})
+}
+
 // checkMixedChannelRisk checks whether a group mixes channels (Antigravity + Anthropic).
 // If it does, an error is returned so the user can confirm explicitly.
 func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
@@ -1306,6 +1456,51 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
 	return nil
 }
 
+func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
+	if s.proxyLatencyCache == nil || len(proxies) == 0 {
+		return
+	}
+
+	ids := make([]int64, 0, len(proxies))
+	for i := range proxies {
+		ids = append(ids, proxies[i].ID)
+	}
+
+	latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids)
+	if err != nil {
+		log.Printf("Warning: load proxy latency cache failed: %v", err)
+		return
+	}
+
+	for i := range proxies {
+		info := latencies[proxies[i].ID]
+		if info == nil {
+			continue
+		}
+		if info.Success {
+			proxies[i].LatencyStatus = "success"
+			proxies[i].LatencyMs = info.LatencyMs
+		} else {
+			proxies[i].LatencyStatus = "failed"
+		}
+		proxies[i].LatencyMessage = info.Message
+		proxies[i].IPAddress = info.IPAddress
+		proxies[i].Country = info.Country
+		proxies[i].CountryCode = info.CountryCode
+		proxies[i].Region = info.Region
+		proxies[i].City = info.City
+	}
+}
+
+func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) {
+	if s.proxyLatencyCache == nil || info == nil {
+		return
+	}
+	if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
+		log.Printf("Warning: store proxy latency cache failed: %v", err)
+	}
+}
+
 // getAccountPlatform maps an account's platform value to the identifier used by mixed-channel checks.
 func getAccountPlatform(accountPlatform string) string {
 	switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
@@ -154,6 +154,8 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
 
 type proxyRepoStub struct {
 	deleteErr    error
+	countErr     error
+	accountCount int64
 	deletedIDs   []int64
 }
@@ -199,7 +201,14 @@ func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, p
 }
 
 func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
-	panic("unexpected CountAccountsByProxyID call")
+	if s.countErr != nil {
+		return 0, s.countErr
+	}
+	return s.accountCount, nil
+}
+
+func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
+	panic("unexpected ListAccountSummariesByProxyID call")
 }
 
 type redeemRepoStub struct {
@@ -409,6 +418,15 @@ func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
|
|||||||
require.Equal(t, []int64{404}, repo.deletedIDs)
|
require.Equal(t, []int64{404}, repo.deletedIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdminService_DeleteProxy_InUse(t *testing.T) {
|
||||||
|
repo := &proxyRepoStub{accountCount: 2}
|
||||||
|
svc := &adminServiceImpl{proxyRepo: repo}
|
||||||
|
|
||||||
|
err := svc.DeleteProxy(context.Background(), 77)
|
||||||
|
require.ErrorIs(t, err, ErrProxyInUse)
|
||||||
|
require.Empty(t, repo.deletedIDs)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAdminService_DeleteProxy_Error(t *testing.T) {
|
func TestAdminService_DeleteProxy_Error(t *testing.T) {
|
||||||
deleteErr := errors.New("delete failed")
|
deleteErr := errors.New("delete failed")
|
||||||
repo := &proxyRepoStub{deleteErr: deleteErr}
|
repo := &proxyRepoStub{deleteErr: deleteErr}
|
||||||
|
|||||||
@@ -564,6 +564,10 @@ urlFallbackLoop:
 		}

 		upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
+		// Capture upstream request body for ops retry of this attempt.
+		if c != nil {
+			c.Set(OpsUpstreamRequestBodyKey, string(geminiBody))
+		}
 		if err != nil {
 			return nil, err
 		}
@@ -574,6 +578,7 @@ urlFallbackLoop:
 			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 				Platform:           account.Platform,
 				AccountID:          account.ID,
+				AccountName:        account.Name,
 				UpstreamStatusCode: 0,
 				Kind:               "request_error",
 				Message:            safeErr,
@@ -615,6 +620,7 @@ urlFallbackLoop:
 			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 				Platform:           account.Platform,
 				AccountID:          account.ID,
+				AccountName:        account.Name,
 				UpstreamStatusCode: resp.StatusCode,
 				UpstreamRequestID:  resp.Header.Get("x-request-id"),
 				Kind:               "retry",
@@ -645,6 +651,7 @@ urlFallbackLoop:
 			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 				Platform:           account.Platform,
 				AccountID:          account.ID,
+				AccountName:        account.Name,
 				UpstreamStatusCode: resp.StatusCode,
 				UpstreamRequestID:  resp.Header.Get("x-request-id"),
 				Kind:               "retry",
@@ -697,6 +704,7 @@ urlFallbackLoop:
 			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 				Platform:           account.Platform,
 				AccountID:          account.ID,
+				AccountName:        account.Name,
 				UpstreamStatusCode: resp.StatusCode,
 				UpstreamRequestID:  resp.Header.Get("x-request-id"),
 				Kind:               "signature_error",
@@ -740,6 +748,7 @@ urlFallbackLoop:
 			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 				Platform:           account.Platform,
 				AccountID:          account.ID,
+				AccountName:        account.Name,
 				UpstreamStatusCode: 0,
 				Kind:               "signature_retry_request_error",
 				Message:            sanitizeUpstreamErrorMessage(retryErr.Error()),
@@ -770,6 +779,7 @@ urlFallbackLoop:
 			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 				Platform:           account.Platform,
 				AccountID:          account.ID,
+				AccountName:        account.Name,
 				UpstreamStatusCode: retryResp.StatusCode,
 				UpstreamRequestID:  retryResp.Header.Get("x-request-id"),
 				Kind:               kind,
@@ -817,6 +827,7 @@ urlFallbackLoop:
 			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 				Platform:           account.Platform,
 				AccountID:          account.ID,
+				AccountName:        account.Name,
 				UpstreamStatusCode: resp.StatusCode,
 				UpstreamRequestID:  resp.Header.Get("x-request-id"),
 				Kind:               "failover",
@@ -1371,6 +1382,7 @@ urlFallbackLoop:
 			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 				Platform:           account.Platform,
 				AccountID:          account.ID,
+				AccountName:        account.Name,
 				UpstreamStatusCode: 0,
 				Kind:               "request_error",
 				Message:            safeErr,
@@ -1412,6 +1424,7 @@ urlFallbackLoop:
 			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 				Platform:           account.Platform,
 				AccountID:          account.ID,
+				AccountName:        account.Name,
 				UpstreamStatusCode: resp.StatusCode,
 				UpstreamRequestID:  resp.Header.Get("x-request-id"),
 				Kind:               "retry",
@@ -1442,6 +1455,7 @@ urlFallbackLoop:
 			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 				Platform:           account.Platform,
 				AccountID:          account.ID,
+				AccountName:        account.Name,
 				UpstreamStatusCode: resp.StatusCode,
 				UpstreamRequestID:  resp.Header.Get("x-request-id"),
 				Kind:               "retry",
@@ -1543,6 +1557,7 @@ urlFallbackLoop:
 			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 				Platform:           account.Platform,
 				AccountID:          account.ID,
+				AccountName:        account.Name,
 				UpstreamStatusCode: resp.StatusCode,
 				UpstreamRequestID:  requestID,
 				Kind:               "failover",
@@ -1559,6 +1574,7 @@ urlFallbackLoop:
 			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 				Platform:           account.Platform,
 				AccountID:          account.ID,
+				AccountName:        account.Name,
 				UpstreamStatusCode: resp.StatusCode,
 				UpstreamRequestID:  requestID,
 				Kind:               "http_error",
@@ -2039,6 +2055,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
 	appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 		Platform:           account.Platform,
 		AccountID:          account.ID,
+		AccountName:        account.Name,
 		UpstreamStatusCode: upstreamStatus,
 		UpstreamRequestID:  upstreamRequestID,
 		Kind:               "http_error",
@@ -49,6 +49,9 @@ func (a *Account) IsSchedulableForModel(requestedModel string) bool {
 	if !a.IsSchedulable() {
 		return false
 	}
+	if a.isModelRateLimited(requestedModel) {
+		return false
+	}
 	if a.Platform != PlatformAntigravity {
 		return true
 	}
@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
 		return "", errors.New("not an antigravity oauth account")
 	}

-	cacheKey := antigravityTokenCacheKey(account)
+	cacheKey := AntigravityTokenCacheKey(account)

 	// 1. Try the cache first
 	if p.tokenCache != nil {
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
 	return accessToken, nil
 }

-func antigravityTokenCacheKey(account *Account) string {
+func AntigravityTokenCacheKey(account *Account) string {
 	projectID := strings.TrimSpace(account.GetCredential("project_id"))
 	if projectID != "" {
 		return "ag:" + projectID
@@ -37,6 +37,11 @@ type APIKeyAuthGroupSnapshot struct {
 	ImagePrice4K    *float64 `json:"image_price_4k,omitempty"`
 	ClaudeCodeOnly  bool     `json:"claude_code_only"`
 	FallbackGroupID *int64   `json:"fallback_group_id,omitempty"`
+
+	// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
+	// Only anthropic groups use these fields; others may leave them empty.
+	ModelRouting        map[string][]int64 `json:"model_routing,omitempty"`
+	ModelRoutingEnabled bool               `json:"model_routing_enabled"`
 }

 // APIKeyAuthCacheEntry is a cache entry with negative-caching support
@@ -221,6 +221,8 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
 			ImagePrice4K:    apiKey.Group.ImagePrice4K,
 			ClaudeCodeOnly:  apiKey.Group.ClaudeCodeOnly,
 			FallbackGroupID: apiKey.Group.FallbackGroupID,
+			ModelRouting:        apiKey.Group.ModelRouting,
+			ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
 		}
 	}
 	return snapshot
@@ -263,6 +265,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
 			ImagePrice4K:    snapshot.Group.ImagePrice4K,
 			ClaudeCodeOnly:  snapshot.Group.ClaudeCodeOnly,
 			FallbackGroupID: snapshot.Group.FallbackGroupID,
+			ModelRouting:        snapshot.Group.ModelRouting,
+			ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
 		}
 	}
 	return apiKey
@@ -178,6 +178,10 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
 			Status:           StatusActive,
 			SubscriptionType: SubscriptionTypeStandard,
 			RateMultiplier:   1,
+			ModelRoutingEnabled: true,
+			ModelRouting: map[string][]int64{
+				"claude-opus-*": {1, 2},
+			},
 		},
 	},
 }
@@ -190,6 +194,8 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
 	require.Equal(t, int64(1), apiKey.ID)
 	require.Equal(t, int64(2), apiKey.User.ID)
 	require.Equal(t, groupID, apiKey.Group.ID)
+	require.True(t, apiKey.Group.ModelRoutingEnabled)
+	require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
 }

 func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
backend/internal/service/claude_token_provider.go (new file, 208 lines)
@@ -0,0 +1,208 @@
+package service
+
+import (
+	"context"
+	"errors"
+	"log/slog"
+	"strconv"
+	"strings"
+	"time"
+)
+
+const (
+	claudeTokenRefreshSkew = 3 * time.Minute
+	claudeTokenCacheSkew   = 5 * time.Minute
+	claudeLockWaitTime     = 200 * time.Millisecond
+)
+
+// ClaudeTokenCache is the token cache interface (reuses the GeminiTokenCache definition).
+type ClaudeTokenCache = GeminiTokenCache
+
+// ClaudeTokenProvider manages the access_token for Claude (Anthropic) OAuth accounts.
+type ClaudeTokenProvider struct {
+	accountRepo  AccountRepository
+	tokenCache   ClaudeTokenCache
+	oauthService *OAuthService
+}
+
+func NewClaudeTokenProvider(
+	accountRepo AccountRepository,
+	tokenCache ClaudeTokenCache,
+	oauthService *OAuthService,
+) *ClaudeTokenProvider {
+	return &ClaudeTokenProvider{
+		accountRepo:  accountRepo,
+		tokenCache:   tokenCache,
+		oauthService: oauthService,
+	}
+}
+
+// GetAccessToken returns a valid access_token.
+func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
+	if account == nil {
+		return "", errors.New("account is nil")
+	}
+	if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
+		return "", errors.New("not an anthropic oauth account")
+	}
+
+	cacheKey := ClaudeTokenCacheKey(account)
+
+	// 1. Try the cache first.
+	if p.tokenCache != nil {
+		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+			slog.Debug("claude_token_cache_hit", "account_id", account.ID)
+			return token, nil
+		} else if err != nil {
+			slog.Warn("claude_token_cache_get_failed", "account_id", account.ID, "error", err)
+		}
+	}
+
+	slog.Debug("claude_token_cache_miss", "account_id", account.ID)
+
+	// 2. Refresh if the token is about to expire.
+	expiresAt := account.GetCredentialAsTime("expires_at")
+	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
+	refreshFailed := false
+	if needsRefresh && p.tokenCache != nil {
+		locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
+		if lockErr == nil && locked {
+			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
+
+			// Check the cache again after acquiring the lock (another worker may have refreshed).
+			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+				return token, nil
+			}
+
+			// Load the latest account state from the database.
+			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
+			if err == nil && fresh != nil {
+				account = fresh
+			}
+			expiresAt = account.GetCredentialAsTime("expires_at")
+			if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
+				if p.oauthService == nil {
+					slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
+					refreshFailed = true // cannot refresh, mark as failed
+				} else {
+					tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
+					if err != nil {
+						// Log a warning on refresh failure, but do not return immediately; try the existing token.
+						slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
+						refreshFailed = true // mark failure so a short TTL is used
+					} else {
+						// Build new credentials, preserving existing fields.
+						newCredentials := make(map[string]any)
+						for k, v := range account.Credentials {
+							newCredentials[k] = v
+						}
+						newCredentials["access_token"] = tokenInfo.AccessToken
+						newCredentials["token_type"] = tokenInfo.TokenType
+						newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
+						newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
+						if tokenInfo.RefreshToken != "" {
+							newCredentials["refresh_token"] = tokenInfo.RefreshToken
+						}
+						if tokenInfo.Scope != "" {
+							newCredentials["scope"] = tokenInfo.Scope
+						}
+						account.Credentials = newCredentials
+						if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
+							slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
+						}
+						expiresAt = account.GetCredentialAsTime("expires_at")
+					}
+				}
+			}
+		} else if lockErr != nil {
+			// A Redis error prevented acquiring the lock; degrade to a lockless refresh (only when the token is near expiry).
+			slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
+
+			// Check whether ctx has already been cancelled.
+			if ctx.Err() != nil {
+				return "", ctx.Err()
+			}
+
+			// Load the latest account state from the database.
+			if p.accountRepo != nil {
+				fresh, err := p.accountRepo.GetByID(ctx, account.ID)
+				if err == nil && fresh != nil {
+					account = fresh
+				}
+			}
+			expiresAt = account.GetCredentialAsTime("expires_at")
+
+			// Only perform the lockless refresh when expires_at is expired or near expiry.
+			if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
+				if p.oauthService == nil {
+					slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
+					refreshFailed = true
+				} else {
+					tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
+					if err != nil {
+						slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
+						refreshFailed = true
+					} else {
+						// Build new credentials, preserving existing fields.
+						newCredentials := make(map[string]any)
+						for k, v := range account.Credentials {
+							newCredentials[k] = v
+						}
+						newCredentials["access_token"] = tokenInfo.AccessToken
+						newCredentials["token_type"] = tokenInfo.TokenType
+						newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
+						newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
+						if tokenInfo.RefreshToken != "" {
+							newCredentials["refresh_token"] = tokenInfo.RefreshToken
+						}
+						if tokenInfo.Scope != "" {
+							newCredentials["scope"] = tokenInfo.Scope
+						}
+						account.Credentials = newCredentials
+						if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
+							slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
+						}
+						expiresAt = account.GetCredentialAsTime("expires_at")
+					}
+				}
+			}
+		} else {
+			// Lock acquisition failed (held by another worker); wait 200ms, then retry the cache.
+			time.Sleep(claudeLockWaitTime)
+			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+				slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
+				return token, nil
+			}
+		}
+	}
+
+	accessToken := account.GetCredential("access_token")
+	if strings.TrimSpace(accessToken) == "" {
+		return "", errors.New("access_token not found in credentials")
+	}
+
+	// 3. Store in the cache.
+	if p.tokenCache != nil {
+		ttl := 30 * time.Minute
+		if refreshFailed {
+			// Use a short TTL after a failed refresh so a stale token is not cached long enough to cause 401 churn.
+			ttl = time.Minute
+			slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
+		} else if expiresAt != nil {
+			until := time.Until(*expiresAt)
+			switch {
+			case until > claudeTokenCacheSkew:
+				ttl = until - claudeTokenCacheSkew
+			case until > 0:
+				ttl = until
+			default:
+				ttl = time.Minute
+			}
+		}
+		if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
+			slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
+		}
+	}
+
+	return accessToken, nil
+}
backend/internal/service/claude_token_provider_test.go (new file, 939 lines)
@@ -0,0 +1,939 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// claudeTokenCacheStub implements ClaudeTokenCache for testing
|
||||||
|
type claudeTokenCacheStub struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
tokens map[string]string
|
||||||
|
getErr error
|
||||||
|
setErr error
|
||||||
|
deleteErr error
|
||||||
|
lockAcquired bool
|
||||||
|
lockErr error
|
||||||
|
releaseLockErr error
|
||||||
|
getCalled int32
|
||||||
|
setCalled int32
|
||||||
|
lockCalled int32
|
||||||
|
unlockCalled int32
|
||||||
|
simulateLockRace bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClaudeTokenCacheStub() *claudeTokenCacheStub {
|
||||||
|
return &claudeTokenCacheStub{
|
||||||
|
tokens: make(map[string]string),
|
||||||
|
lockAcquired: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||||
|
atomic.AddInt32(&s.getCalled, 1)
|
||||||
|
if s.getErr != nil {
|
||||||
|
return "", s.getErr
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.tokens[cacheKey], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||||
|
atomic.AddInt32(&s.setCalled, 1)
|
||||||
|
if s.setErr != nil {
|
||||||
|
return s.setErr
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.tokens[cacheKey] = token
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||||
|
if s.deleteErr != nil {
|
||||||
|
return s.deleteErr
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.tokens, cacheKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||||
|
atomic.AddInt32(&s.lockCalled, 1)
|
||||||
|
if s.lockErr != nil {
|
||||||
|
return false, s.lockErr
|
||||||
|
}
|
||||||
|
if s.simulateLockRace {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return s.lockAcquired, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||||
|
atomic.AddInt32(&s.unlockCalled, 1)
|
||||||
|
return s.releaseLockErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider
|
||||||
|
type claudeAccountRepoStub struct {
|
||||||
|
account *Account
|
||||||
|
getErr error
|
||||||
|
updateErr error
|
||||||
|
getCalled int32
|
||||||
|
updateCalled int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *claudeAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||||
|
atomic.AddInt32(&r.getCalled, 1)
|
||||||
|
if r.getErr != nil {
|
||||||
|
return nil, r.getErr
|
||||||
|
}
|
||||||
|
return r.account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *claudeAccountRepoStub) Update(ctx context.Context, account *Account) error {
|
||||||
|
atomic.AddInt32(&r.updateCalled, 1)
|
||||||
|
if r.updateErr != nil {
|
||||||
|
return r.updateErr
|
||||||
|
}
|
||||||
|
r.account = account
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// claudeOAuthServiceStub implements OAuthService methods for testing
|
||||||
|
type claudeOAuthServiceStub struct {
|
||||||
|
tokenInfo *TokenInfo
|
||||||
|
refreshErr error
|
||||||
|
refreshCalled int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
|
||||||
|
atomic.AddInt32(&s.refreshCalled, 1)
|
||||||
|
if s.refreshErr != nil {
|
||||||
|
return nil, s.refreshErr
|
||||||
|
}
|
||||||
|
return s.tokenInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// testClaudeTokenProvider is a test version that uses the stub OAuth service
|
||||||
|
type testClaudeTokenProvider struct {
|
||||||
|
accountRepo *claudeAccountRepoStub
|
||||||
|
tokenCache *claudeTokenCacheStub
|
||||||
|
oauthService *claudeOAuthServiceStub
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
|
if account == nil {
|
||||||
|
return "", errors.New("account is nil")
|
||||||
|
}
|
||||||
|
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||||
|
return "", errors.New("not an anthropic oauth account")
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheKey := ClaudeTokenCacheKey(account)
|
||||||
|
|
||||||
|
// 1. Check cache
|
||||||
|
if p.tokenCache != nil {
|
||||||
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check if refresh needed
|
||||||
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||||
|
refreshFailed := false
|
||||||
|
if needsRefresh && p.tokenCache != nil {
|
||||||
|
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
|
if err == nil && locked {
|
||||||
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
|
|
||||||
|
// Check cache again after acquiring lock
|
||||||
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get fresh account from DB
|
||||||
|
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||||
|
if err == nil && fresh != nil {
|
||||||
|
account = fresh
|
||||||
|
}
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||||
|
if p.oauthService == nil {
|
||||||
|
refreshFailed = true // 无法刷新,标记失败
|
||||||
|
} else {
|
||||||
|
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||||
|
} else {
|
||||||
|
// Build new credentials
|
||||||
|
newCredentials := make(map[string]any)
|
||||||
|
for k, v := range account.Credentials {
|
||||||
|
newCredentials[k] = v
|
||||||
|
}
|
||||||
|
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||||
|
newCredentials["token_type"] = tokenInfo.TokenType
|
||||||
|
newCredentials["expires_at"] = time.Now().Add(time.Duration(tokenInfo.ExpiresIn) * time.Second).Format(time.RFC3339)
|
||||||
|
if tokenInfo.RefreshToken != "" {
|
||||||
|
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||||
|
}
|
||||||
|
account.Credentials = newCredentials
|
||||||
|
_ = p.accountRepo.Update(ctx, account)
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if p.tokenCache.simulateLockRace {
|
||||||
|
// Wait and retry cache
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken := account.GetCredential("access_token")
|
||||||
|
if accessToken == "" {
|
||||||
|
return "", errors.New("access_token not found in credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Store in cache
|
||||||
|
if p.tokenCache != nil {
|
||||||
|
ttl := 30 * time.Minute
|
||||||
|
if refreshFailed {
|
||||||
|
ttl = time.Minute // 刷新失败时使用短 TTL
|
||||||
|
} else if expiresAt != nil {
|
||||||
|
until := time.Until(*expiresAt)
|
||||||
|
if until > claudeTokenCacheSkew {
|
||||||
|
ttl = until - claudeTokenCacheSkew
|
||||||
|
} else if until > 0 {
|
||||||
|
ttl = until
|
||||||
|
} else {
|
||||||
|
ttl = time.Minute
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
return accessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeTokenProvider_CacheHit(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	account := &Account{
		ID:       100,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "db-token",
		},
	}
	cacheKey := ClaudeTokenCacheKey(account)
	cache.tokens[cacheKey] = "cached-token"

	provider := NewClaudeTokenProvider(nil, cache, nil)

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "cached-token", token)
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}

func TestClaudeTokenProvider_CacheMiss_FromCredentials(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	// Token expires in far future, no refresh needed
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       101,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "credential-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewClaudeTokenProvider(nil, cache, nil)

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "credential-token", token)

	// Should have stored in cache
	cacheKey := ClaudeTokenCacheKey(account)
	require.Equal(t, "credential-token", cache.tokens[cacheKey])
}

func TestClaudeTokenProvider_TokenRefresh(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh-token",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
			ExpiresAt:    time.Now().Add(time.Hour).Unix(),
		},
	}

	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       102,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account

	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
	require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}

func TestClaudeTokenProvider_LockRaceCondition(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.simulateLockRace = true
	accountRepo := &claudeAccountRepoStub{}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       103,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "race-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account

	// Simulate another worker already refreshed and cached
	cacheKey := ClaudeTokenCacheKey(account)
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()

	provider := &testClaudeTokenProvider{
		accountRepo: accountRepo,
		tokenCache:  cache,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}

func TestClaudeTokenProvider_NilAccount(t *testing.T) {
	provider := NewClaudeTokenProvider(nil, nil, nil)

	token, err := provider.GetAccessToken(context.Background(), nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), "account is nil")
	require.Empty(t, token)
}

func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
	provider := NewClaudeTokenProvider(nil, nil, nil)
	account := &Account{
		ID:       104,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an anthropic oauth account")
	require.Empty(t, token)
}

func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
	provider := NewClaudeTokenProvider(nil, nil, nil)
	account := &Account{
		ID:       105,
		Platform: PlatformAnthropic,
		Type:     AccountTypeAPIKey,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an anthropic oauth account")
	require.Empty(t, token)
}

func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
	provider := NewClaudeTokenProvider(nil, nil, nil)
	account := &Account{
		ID:       106,
		Platform: PlatformAnthropic,
		Type:     AccountTypeSetupToken,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an anthropic oauth account")
	require.Empty(t, token)
}

func TestClaudeTokenProvider_NilCache(t *testing.T) {
	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       107,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "nocache-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewClaudeTokenProvider(nil, nil, nil)

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "nocache-token", token)
}

func TestClaudeTokenProvider_CacheGetError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.getErr = errors.New("redis connection failed")

	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       108,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewClaudeTokenProvider(nil, cache, nil)

	// Should gracefully degrade and return from credentials
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-token", token)
}

func TestClaudeTokenProvider_CacheSetError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.setErr = errors.New("redis write failed")

	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       109,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "still-works-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewClaudeTokenProvider(nil, cache, nil)

	// Should still work even if cache set fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "still-works-token", token)
}

func TestClaudeTokenProvider_MissingAccessToken(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       110,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"expires_at": expiresAt,
			// missing access_token
		},
	}

	provider := NewClaudeTokenProvider(nil, cache, nil)

	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}

func TestClaudeTokenProvider_RefreshError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{}
	oauthService := &claudeOAuthServiceStub{
		refreshErr: errors.New("oauth refresh failed"),
	}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       111,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account

	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}

	// Now with fallback behavior, should return existing token even if refresh fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}

func TestClaudeTokenProvider_OAuthServiceNotConfigured(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       112,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account

	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: nil, // not configured
	}

	// Now with fallback behavior, should return existing token even if oauth service not configured
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}

func TestClaudeTokenProvider_TTLCalculation(t *testing.T) {
	tests := []struct {
		name      string
		expiresIn time.Duration
	}{
		{
			name:      "far_future_expiry",
			expiresIn: 1 * time.Hour,
		},
		{
			name:      "medium_expiry",
			expiresIn: 10 * time.Minute,
		},
		{
			name:      "near_expiry",
			expiresIn: 6 * time.Minute,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cache := newClaudeTokenCacheStub()
			expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
			account := &Account{
				ID:       200,
				Platform: PlatformAnthropic,
				Type:     AccountTypeOAuth,
				Credentials: map[string]any{
					"access_token": "test-token",
					"expires_at":   expiresAt,
				},
			}

			provider := NewClaudeTokenProvider(nil, cache, nil)

			_, err := provider.GetAccessToken(context.Background(), account)
			require.NoError(t, err)

			// Verify token was cached
			cacheKey := ClaudeTokenCacheKey(account)
			require.Equal(t, "test-token", cache.tokens[cacheKey])
		})
	}
}

func TestClaudeTokenProvider_AccountRepoGetError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{
		getErr: errors.New("db connection failed"),
	}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
		},
	}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       113,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh",
			"expires_at":    expiresAt,
		},
	}

	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}

	// Should still work, just using the passed-in account
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
}

func TestClaudeTokenProvider_AccountUpdateError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{
		updateErr: errors.New("db write failed"),
	}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
		},
	}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       114,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account

	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}

	// Should still return token even if update fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
}

func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "new-access-token",
			RefreshToken: "new-refresh-token",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
		},
	}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       115,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-access-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
			"custom_field":  "should-be-preserved",
			"organization":  "test-org",
		},
	}
	accountRepo.account = account

	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "new-access-token", token)

	// Verify existing fields are preserved
	require.Equal(t, "should-be-preserved", accountRepo.account.Credentials["custom_field"])
	require.Equal(t, "test-org", accountRepo.account.Credentials["organization"])
	// Verify new fields are updated
	require.Equal(t, "new-access-token", accountRepo.account.Credentials["access_token"])
	require.Equal(t, "new-refresh-token", accountRepo.account.Credentials["refresh_token"])
}

func TestClaudeTokenProvider_DoubleCheckCacheAfterLock(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
		},
	}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       116,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	cacheKey := ClaudeTokenCacheKey(account)

	// After lock is acquired, cache should have the token (simulating another worker)
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "cached-by-other-worker"
		cache.mu.Unlock()
	}()

	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}

// Tests for real provider - to increase coverage
func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails

	// Token expires soon (within refresh skew) to trigger lock attempt
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       300,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}

	// Set token in cache after lock wait period (simulate other worker refreshing)
	cacheKey := ClaudeTokenCacheKey(account)
	go func() {
		time.Sleep(100 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "refreshed-by-other"
		cache.mu.Unlock()
	}()

	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}

func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       301,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "original-token",
			"expires_at":   expiresAt,
		},
	}

	cacheKey := ClaudeTokenCacheKey(account)
	// Set token in cache immediately after wait starts
	go func() {
		time.Sleep(50 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()

	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}

func TestClaudeTokenProvider_Real_NoExpiresAt(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.lockAcquired = false // Prevent entering refresh logic

	// Token with nil expires_at (no expiry set)
	account := &Account{
		ID:       302,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "no-expiry-token",
		},
	}

	// After lock wait, return token from credentials
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "no-expiry-token", token)
}

func TestClaudeTokenProvider_Real_WhitespaceToken(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cacheKey := "claude:account:303"
	cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty

	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       303,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "real-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "real-token", token)
}

func TestClaudeTokenProvider_Real_EmptyCredentialToken(t *testing.T) {
	cache := newClaudeTokenCacheStub()

	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       304,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": " ", // Whitespace only
			"expires_at":   expiresAt,
		},
	}

	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}

func TestClaudeTokenProvider_Real_LockError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.lockErr = errors.New("redis lock failed")

	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       305,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-on-lock-error",
			"expires_at":   expiresAt,
		},
	}

	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-on-lock-error", token)
}

func TestClaudeTokenProvider_Real_NilCredentials(t *testing.T) {
	cache := newClaudeTokenCacheStub()

	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       306,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"expires_at": expiresAt,
			// No access_token
		},
	}

	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}

@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
 	return stats, nil
 }
 
-func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
+func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
-	trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID)
+	trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
 	if err != nil {
 		return nil, fmt.Errorf("get usage trend with filters: %w", err)
 	}
 	return trend, nil
 }
 
-func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]usagestats.ModelStat, error) {
+func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
-	stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, 0)
+	stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
 	if err != nil {
 		return nil, fmt.Errorf("get model stats with filters: %w", err)
 	}
@@ -142,6 +142,9 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
 func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
 	return nil
 }
+func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
+	return nil
+}
 func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
 	return nil
 }
@@ -157,6 +160,9 @@ func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int6
 func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
 	return nil
 }
+func (m *mockAccountRepoForPlatform) ClearModelRateLimits(ctx context.Context, id int64) error {
+	return nil
+}
 func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
 	return nil
 }
@@ -1046,13 +1052,67 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
 			concurrencyService: nil, // No concurrency service
 		}
 
-		result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+		result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
 		require.NoError(t, err)
 		require.NotNil(t, result)
 		require.NotNil(t, result.Account)
 		require.Equal(t, int64(1), result.Account.ID, "should select the highest-priority account")
 	})
 
+	t.Run("model routing takes effect even without ConcurrencyService", func(t *testing.T) {
+		groupID := int64(1)
+		sessionHash := "sticky"
+
+		repo := &mockAccountRepoForPlatform{
+			accounts: []Account{
+				{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
+				{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
+			},
+			accountsByID: map[int64]*Account{},
+		}
+		for i := range repo.accounts {
+			repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+		}
+
+		cache := &mockGatewayCacheForPlatform{
+			sessionBindings: map[string]int64{sessionHash: 1},
+		}
+
+		groupRepo := &mockGroupRepoForGateway{
+			groups: map[int64]*Group{
+				groupID: {
+					ID:                  groupID,
+					Platform:            PlatformAnthropic,
+					Status:              StatusActive,
+					Hydrated:            true,
+					ModelRoutingEnabled: true,
+					ModelRouting: map[string][]int64{
+						"claude-a": {1},
+						"claude-b": {2},
+					},
+				},
+			},
+		}
+
+		cfg := testConfig()
+		cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+		svc := &GatewayService{
+			accountRepo:        repo,
+			groupRepo:          groupRepo,
+			cache:              cache,
+			cfg:                cfg,
+			concurrencyService: nil, // legacy path
+		}
+
+		result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.NotNil(t, result.Account)
+		require.Equal(t, int64(2), result.Account.ID, "switching to claude-b should switch accounts per model routing")
+		require.Equal(t, int64(2), cache.sessionBindings[sessionHash], "sticky binding should be updated to the routed account")
+	})
+
 	t.Run("no ConcurrencyService - falls back to legacy selection", func(t *testing.T) {
 		repo := &mockAccountRepoForPlatform{
 			accounts: []Account{
@@ -1077,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
 			concurrencyService: nil,
 		}
 
-		result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+		result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
 		require.NoError(t, err)
 		require.NotNil(t, result)
 		require.NotNil(t, result.Account)
@@ -1109,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
 		}
 
 		excludedIDs := map[int64]struct{}{1: {}}
-		result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -1143,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -1179,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -1206,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: nil,
|
concurrencyService: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, result)
|
require.Nil(t, result)
|
||||||
require.Contains(t, err.Error(), "no available accounts")
|
require.Contains(t, err.Error(), "no available accounts")
|
||||||
@@ -1238,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: nil,
|
concurrencyService: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -1271,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
concurrencyService: nil,
|
concurrencyService: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.Account)
|
require.NotNil(t, result.Account)
|
||||||
@@ -1341,6 +1401,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T)
 		ID:       groupID,
 		Platform: PlatformAnthropic,
 		Status:   StatusActive,
+		Hydrated: true,
 	}
 	groupRepo := &mockGroupRepoForGateway{
 		groups: map[int64]*Group{groupID: group},
@@ -1398,6 +1459,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
 		ID:       fallbackID,
 		Platform: PlatformAnthropic,
 		Status:   StatusActive,
+		Hydrated: true,
 	}
 	ctx = context.WithValue(ctx, ctxkey.Group, group)
 
@@ -12,6 +12,7 @@ import (
 	"io"
 	"log"
 	"net/http"
+	"os"
 	"regexp"
 	"sort"
 	"strings"
@@ -40,6 +41,21 @@ const (
 	maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
 )
 
+func (s *GatewayService) debugModelRoutingEnabled() bool {
+	v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
+	return v == "1" || v == "true" || v == "yes" || v == "on"
+}
+
+func shortSessionHash(sessionHash string) string {
+	if sessionHash == "" {
+		return ""
+	}
+	if len(sessionHash) <= 8 {
+		return sessionHash
+	}
+	return sessionHash[:8]
+}
+
 // sseDataRe matches SSE data lines with optional whitespace after colon.
 // Some upstream APIs return non-standard "data:" without space (should be "data: ").
 var (
@@ -196,6 +212,8 @@ type GatewayService struct {
 	httpUpstream        HTTPUpstream
 	deferredService     *DeferredService
 	concurrencyService  *ConcurrencyService
+	claudeTokenProvider *ClaudeTokenProvider
+	sessionLimitCache   SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
 }
 
 // NewGatewayService creates a new GatewayService
@@ -215,6 +233,8 @@ func NewGatewayService(
 	identityService *IdentityService,
 	httpUpstream HTTPUpstream,
 	deferredService *DeferredService,
+	claudeTokenProvider *ClaudeTokenProvider,
+	sessionLimitCache SessionLimitCache,
 ) *GatewayService {
 	return &GatewayService{
 		accountRepo: accountRepo,
@@ -232,6 +252,8 @@ func NewGatewayService(
 		identityService:     identityService,
 		httpUpstream:        httpUpstream,
 		deferredService:     deferredService,
+		claudeTokenProvider: claudeTokenProvider,
+		sessionLimitCache:   sessionLimitCache,
 	}
 }
 
@@ -797,8 +819,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
 }
 
 // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
-func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
+// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
+func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
 	cfg := s.schedulingConfig()
+	// 提取会话 UUID(用于会话数量限制)
+	sessionUUID := extractSessionUUID(metadataUserID)
+
 	var stickyAccountID int64
 	if sessionHash != "" && s.cache != nil {
 		if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
@@ -813,6 +839,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
 	}
 	ctx = s.withGroupContext(ctx, group)
 
+	if s.debugModelRoutingEnabled() && requestedModel != "" {
+		groupPlatform := ""
+		if group != nil {
+			groupPlatform = group.Platform
+		}
+		log.Printf("[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v",
+			derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil)
+	}
+
 	if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
 		account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
 		if err != nil {
@@ -856,6 +891,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
 		return nil, err
 	}
 	preferOAuth := platform == PlatformGemini
+	if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" {
+		log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
+	}
 
 	accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
 	if err != nil {
@@ -873,22 +911,235 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
 		return excluded
 	}
 
-	// ============ Layer 1: 粘性会话优先 ============
-	if sessionHash != "" && s.cache != nil {
-		accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
-		if err == nil && accountID > 0 && !isExcluded(accountID) {
-			// 粘性命中仅在当前可调度候选集中生效。
+	// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
 	accountByID := make(map[int64]*Account, len(accounts))
 	for i := range accounts {
 		accountByID[accounts[i].ID] = &accounts[i]
 	}
 
+	// 获取模型路由配置(仅 anthropic 平台)
+	var routingAccountIDs []int64
+	if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
+		routingAccountIDs = group.GetRoutingAccountIDs(requestedModel)
+		if s.debugModelRoutingEnabled() {
+			log.Printf("[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d",
+				group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID)
+			if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 {
+				keys := make([]string, 0, len(group.ModelRouting))
+				for k := range group.ModelRouting {
+					keys = append(keys, k)
+				}
+				sort.Strings(keys)
+				const maxKeys = 20
+				if len(keys) > maxKeys {
+					keys = keys[:maxKeys]
+				}
+				log.Printf("[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys)
+			}
+		}
+	}
+
+	// ============ Layer 1: 模型路由优先选择(优先级高于粘性会话) ============
+	if len(routingAccountIDs) > 0 && s.concurrencyService != nil {
+		// 1. 过滤出路由列表中可调度的账号
+		var routingCandidates []*Account
+		var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
+		for _, routingAccountID := range routingAccountIDs {
+			if isExcluded(routingAccountID) {
+				filteredExcluded++
+				continue
+			}
+			account, ok := accountByID[routingAccountID]
+			if !ok || !account.IsSchedulable() {
+				if !ok {
+					filteredMissing++
+				} else {
+					filteredUnsched++
+				}
+				continue
+			}
+			if !s.isAccountAllowedForPlatform(account, platform, useMixed) {
+				filteredPlatform++
+				continue
+			}
+			if !account.IsSchedulableForModel(requestedModel) {
+				filteredModelScope++
+				continue
+			}
+			if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
+				filteredModelMapping++
+				continue
+			}
+			// 窗口费用检查(非粘性会话路径)
+			if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
+				filteredWindowCost++
+				continue
+			}
+			routingCandidates = append(routingCandidates, account)
+		}
+
+		if s.debugModelRoutingEnabled() {
+			log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
+				derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
+				filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
+		}
+
+		if len(routingCandidates) > 0 {
+			// 1.5. 在路由账号范围内检查粘性会话
+			if sessionHash != "" && s.cache != nil {
+				stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
+				if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
+					// 粘性账号在路由列表中,优先使用
+					if stickyAccount, ok := accountByID[stickyAccountID]; ok {
+						if stickyAccount.IsSchedulable() &&
+							s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
+							stickyAccount.IsSchedulableForModel(requestedModel) &&
+							(requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) &&
+							s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
+							result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
+							if err == nil && result.Acquired {
+								// 会话数量限制检查
+								if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
+									result.ReleaseFunc() // 释放槽位
+									// 继续到负载感知选择
+								} else {
+									_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
+									if s.debugModelRoutingEnabled() {
+										log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
+									}
+									return &AccountSelectionResult{
+										Account:     stickyAccount,
+										Acquired:    true,
+										ReleaseFunc: result.ReleaseFunc,
+									}, nil
+								}
+							}
+
+							waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
+							if waitingCount < cfg.StickySessionMaxWaiting {
+								return &AccountSelectionResult{
+									Account: stickyAccount,
+									WaitPlan: &AccountWaitPlan{
+										AccountID:      stickyAccountID,
+										MaxConcurrency: stickyAccount.Concurrency,
+										Timeout:        cfg.StickySessionWaitTimeout,
+										MaxWaiting:     cfg.StickySessionMaxWaiting,
+									},
+								}, nil
+							}
+							// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
+						}
+					}
+				}
+			}
+
+			// 2. 批量获取负载信息
+			routingLoads := make([]AccountWithConcurrency, 0, len(routingCandidates))
+			for _, acc := range routingCandidates {
+				routingLoads = append(routingLoads, AccountWithConcurrency{
+					ID:             acc.ID,
+					MaxConcurrency: acc.Concurrency,
+				})
+			}
+			routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
+
+			// 3. 按负载感知排序
+			type accountWithLoad struct {
+				account  *Account
+				loadInfo *AccountLoadInfo
+			}
+			var routingAvailable []accountWithLoad
+			for _, acc := range routingCandidates {
+				loadInfo := routingLoadMap[acc.ID]
+				if loadInfo == nil {
+					loadInfo = &AccountLoadInfo{AccountID: acc.ID}
+				}
+				if loadInfo.LoadRate < 100 {
+					routingAvailable = append(routingAvailable, accountWithLoad{account: acc, loadInfo: loadInfo})
+				}
+			}
+
+			if len(routingAvailable) > 0 {
+				// 排序:优先级 > 负载率 > 最后使用时间
+				sort.SliceStable(routingAvailable, func(i, j int) bool {
+					a, b := routingAvailable[i], routingAvailable[j]
+					if a.account.Priority != b.account.Priority {
+						return a.account.Priority < b.account.Priority
+					}
+					if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
+						return a.loadInfo.LoadRate < b.loadInfo.LoadRate
+					}
+					switch {
+					case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
+						return true
+					case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
+						return false
+					case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
+						return false
+					default:
+						return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
+					}
+				})
+
+				// 4. 尝试获取槽位
+				for _, item := range routingAvailable {
+					result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
+					if err == nil && result.Acquired {
+						// 会话数量限制检查
+						if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
+							result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
+							continue
+						}
+						if sessionHash != "" && s.cache != nil {
+							_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
+						}
+						if s.debugModelRoutingEnabled() {
+							log.Printf("[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
+						}
+						return &AccountSelectionResult{
+							Account:     item.account,
+							Acquired:    true,
+							ReleaseFunc: result.ReleaseFunc,
+						}, nil
+					}
+				}
+
+				// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
+				acc := routingAvailable[0].account
+				if s.debugModelRoutingEnabled() {
+					log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID)
+				}
+				return &AccountSelectionResult{
+					Account: acc,
+					WaitPlan: &AccountWaitPlan{
+						AccountID:      acc.ID,
+						MaxConcurrency: acc.Concurrency,
+						Timeout:        cfg.StickySessionWaitTimeout,
+						MaxWaiting:     cfg.StickySessionMaxWaiting,
+					},
+				}, nil
+			}
+			// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
+			log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
+		}
+	}
+
+	// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
+	if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil {
+		accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
+		if err == nil && accountID > 0 && !isExcluded(accountID) {
 			account, ok := accountByID[accountID]
 			if ok && s.isAccountInGroup(account, groupID) &&
 				s.isAccountAllowedForPlatform(account, platform, useMixed) &&
 				account.IsSchedulableForModel(requestedModel) &&
-				(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
+				(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
+				s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
 				result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
 				if err == nil && result.Acquired {
+					// 会话数量限制检查
+					if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
+						result.ReleaseFunc() // 释放槽位,继续到 Layer 2
+					} else {
 					_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
 					return &AccountSelectionResult{
 						Account: account,
@@ -896,6 +1147,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
ReleaseFunc: result.ReleaseFunc,
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||||
@@ -935,6 +1187,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// 窗口费用检查(非粘性会话路径)
|
||||||
|
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
candidates = append(candidates, acc)
|
candidates = append(candidates, acc)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -952,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
|
|
||||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
|
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -1001,6 +1257,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
for _, item := range available {
|
for _, item := range available {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
|
// 会话数量限制检查
|
||||||
|
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
|
||||||
|
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||||
|
continue
|
||||||
|
}
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
|
||||||
}
|
}
|
||||||
@@ -1030,13 +1291,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
return nil, errors.New("no available accounts")
|
return nil, errors.New("no available accounts")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
|
||||||
ordered := append([]*Account(nil), candidates...)
|
ordered := append([]*Account(nil), candidates...)
|
||||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||||
|
|
||||||
for _, acc := range ordered {
|
for _, acc := range ordered {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
|
// 会话数量限制检查
|
||||||
|
if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
|
||||||
|
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||||
|
continue
|
||||||
|
}
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
|
||||||
}
|
}
|
||||||
@@ -1093,6 +1359,32 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
|
|||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
|
||||||
|
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
group, err := s.resolveGroupByID(ctx, *groupID)
|
||||||
|
if err != nil || group == nil {
|
||||||
|
if s.debugModelRoutingEnabled() {
|
||||||
|
log.Printf("[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Preserve existing behavior: model routing only applies to anthropic groups.
|
||||||
|
if group.Platform != PlatformAnthropic {
|
||||||
|
if s.debugModelRoutingEnabled() {
|
||||||
|
log.Printf("[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ids := group.GetRoutingAccountIDs(requestedModel)
|
||||||
|
if s.debugModelRoutingEnabled() {
|
||||||
|
log.Printf("[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v",
|
||||||
|
group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids)
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
|
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
|
||||||
if groupID == nil {
|
if groupID == nil {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
@@ -1242,6 +1534,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
|
|||||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isAccountSchedulableForWindowCost checks whether the account can be scheduled based on its window cost.
// Only applies to Anthropic OAuth/SetupToken accounts.
// Returns true if the account is schedulable, false otherwise.
func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool {
	// Only check Anthropic OAuth/SetupToken accounts
	if !account.IsAnthropicOAuthOrSetupToken() {
		return true
	}

	limit := account.GetWindowCostLimit()
	if limit <= 0 {
		return true // window cost limit not enabled
	}

	// Try to read the window cost from cache
	var currentCost float64
	if s.sessionLimitCache != nil {
		if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit {
			currentCost = cost
			goto checkSchedulability
		}
	}

	// Cache miss: query the database
	{
		var startTime time.Time
		if account.SessionWindowStart != nil {
			startTime = *account.SessionWindowStart
		} else {
			startTime = time.Now().Add(-5 * time.Hour)
		}

		stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
		if err != nil {
			// Fail open: allow scheduling when the query fails
			return true
		}

		// Use the standard cost (excluding the account rate multiplier)
		currentCost = stats.StandardCost

		// Populate the cache (errors ignored)
		if s.sessionLimitCache != nil {
			_ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost)
		}
	}

checkSchedulability:
	schedulability := account.CheckWindowCostSchedulability(currentCost)

	switch schedulability {
	case WindowCostSchedulable:
		return true
	case WindowCostStickyOnly:
		return isSticky
	case WindowCostNotSchedulable:
		return false
	}
	return true
}

// checkAndRegisterSession checks and registers a session, used for session-count limiting.
// Only applies to Anthropic OAuth/SetupToken accounts.
// Returns true to allow (within the limit, or the session already exists),
// false to reject (over the limit and this is a new session).
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
	// Only check Anthropic OAuth/SetupToken accounts
	if !account.IsAnthropicOAuthOrSetupToken() {
		return true
	}

	maxSessions := account.GetMaxSessions()
	if maxSessions <= 0 || sessionUUID == "" {
		return true // session limit not enabled, or no session ID
	}

	if s.sessionLimitCache == nil {
		return true // allow when the cache is unavailable
	}

	idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute

	allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
	if err != nil {
		// Fail open: allow on cache errors
		return true
	}
	return allowed
}

// extractSessionUUID extracts the session UUID from metadata.user_id.
// Format: user_{64-char hex}_account__session_{uuid}
func extractSessionUUID(metadataUserID string) string {
	if metadataUserID == "" {
		return ""
	}
	if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
		return match[1]
	}
	return ""
}
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
	if s.schedulerSnapshot != nil {
		return s.schedulerSnapshot.GetAccount(ctx, accountID)

@@ -1274,6 +1667,116 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
// selectAccountForModelWithPlatform selects an account within a single platform (fully isolated)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
	preferOAuth := platform == PlatformGemini
	routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)

	var accounts []Account
	accountsLoaded := false

	// ============ Model Routing (legacy path): apply before sticky session ============
	// When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing
	// so switching model can switch upstream account within the same sticky session.
	if len(routingAccountIDs) > 0 {
		if s.debugModelRoutingEnabled() {
			log.Printf("[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
				derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs)
		}
		// 1) Sticky session only applies if the bound account is within the routing set.
		if sessionHash != "" && s.cache != nil {
			accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
			if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
				if _, excluded := excludedIDs[accountID]; !excluded {
					account, err := s.getSchedulableAccount(ctx, accountID)
					// Check group membership and platform match (ensure sticky sessions never cross groups or platforms)
					if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
						if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
							log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
						}
						if s.debugModelRoutingEnabled() {
							log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
						}
						return account, nil
					}
				}
			}
		}

		// 2) Select an account from the routed candidates.
		forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
		if hasForcePlatform && forcePlatform == "" {
			hasForcePlatform = false
		}
		var err error
		accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
		if err != nil {
			return nil, fmt.Errorf("query accounts failed: %w", err)
		}
		accountsLoaded = true

		routingSet := make(map[int64]struct{}, len(routingAccountIDs))
		for _, id := range routingAccountIDs {
			if id > 0 {
				routingSet[id] = struct{}{}
			}
		}

		var selected *Account
		for i := range accounts {
			acc := &accounts[i]
			if _, ok := routingSet[acc.ID]; !ok {
				continue
			}
			if _, excluded := excludedIDs[acc.ID]; excluded {
				continue
			}
			// Scheduler snapshots can be temporarily stale; re-check schedulability here to
			// avoid selecting accounts that were recently rate-limited/overloaded.
			if !acc.IsSchedulable() {
				continue
			}
			if !acc.IsSchedulableForModel(requestedModel) {
				continue
			}
			if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
				continue
			}
			if selected == nil {
				selected = acc
				continue
			}
			if acc.Priority < selected.Priority {
				selected = acc
			} else if acc.Priority == selected.Priority {
				switch {
				case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
					selected = acc
				case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
					// keep selected (never used is preferred)
				case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
					if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
						selected = acc
					}
				default:
					if acc.LastUsedAt.Before(*selected.LastUsedAt) {
						selected = acc
					}
				}
			}
		}

		if selected != nil {
			if sessionHash != "" && s.cache != nil {
				if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
					log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
				}
			}
			if s.debugModelRoutingEnabled() {
				log.Printf("[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
			}
			return selected, nil
		}
		log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
	}

	// 1. Look up the sticky session
	if sessionHash != "" && s.cache != nil {
		accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
@@ -1292,14 +1795,17 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
	}

	// 2. Fetch the schedulable account list (single platform)
	if !accountsLoaded {
		forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
		if hasForcePlatform && forcePlatform == "" {
			hasForcePlatform = false
		}
		var err error
		accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
		if err != nil {
			return nil, fmt.Errorf("query accounts failed: %w", err)
		}
	}

	// 3. Select by priority + least recently used (with model support considered)
	var selected *Account
@@ -1364,6 +1870,115 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// Query native-platform accounts + antigravity accounts with mixed_scheduling enabled
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
	preferOAuth := nativePlatform == PlatformGemini
	routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)

	var accounts []Account
	accountsLoaded := false

	// ============ Model Routing (legacy path): apply before sticky session ============
	if len(routingAccountIDs) > 0 {
		if s.debugModelRoutingEnabled() {
			log.Printf("[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
				derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs)
		}
		// 1) Sticky session only applies if the bound account is within the routing set.
		if sessionHash != "" && s.cache != nil {
			accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
			if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
				if _, excluded := excludedIDs[accountID]; !excluded {
					account, err := s.getSchedulableAccount(ctx, accountID)
					// Check group membership and validity: native platform matches directly; antigravity requires mixed scheduling enabled
					if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
						if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
							if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
								log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
							}
							if s.debugModelRoutingEnabled() {
								log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
							}
							return account, nil
						}
					}
				}
			}
		}

		// 2) Select an account from the routed candidates.
		var err error
		accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
		if err != nil {
			return nil, fmt.Errorf("query accounts failed: %w", err)
		}
		accountsLoaded = true

		routingSet := make(map[int64]struct{}, len(routingAccountIDs))
		for _, id := range routingAccountIDs {
			if id > 0 {
				routingSet[id] = struct{}{}
			}
		}

		var selected *Account
		for i := range accounts {
			acc := &accounts[i]
			if _, ok := routingSet[acc.ID]; !ok {
				continue
			}
			if _, excluded := excludedIDs[acc.ID]; excluded {
				continue
			}
			// Scheduler snapshots can be temporarily stale; re-check schedulability here to
			// avoid selecting accounts that were recently rate-limited/overloaded.
			if !acc.IsSchedulable() {
				continue
			}
			// Filter: native platform passes directly; antigravity requires mixed scheduling enabled
			if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
				continue
			}
			if !acc.IsSchedulableForModel(requestedModel) {
				continue
			}
			if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
				continue
			}
			if selected == nil {
				selected = acc
				continue
			}
			if acc.Priority < selected.Priority {
				selected = acc
			} else if acc.Priority == selected.Priority {
				switch {
				case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
					selected = acc
				case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
					// keep selected (never used is preferred)
				case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
					if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
						selected = acc
					}
				default:
					if acc.LastUsedAt.Before(*selected.LastUsedAt) {
						selected = acc
					}
				}
			}
		}

		if selected != nil {
			if sessionHash != "" && s.cache != nil {
				if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
					log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
				}
			}
			if s.debugModelRoutingEnabled() {
				log.Printf("[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
			}
			return selected, nil
		}
		log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
	}

	// 1. Look up the sticky session
	if sessionHash != "" && s.cache != nil {
@@ -1385,10 +2000,13 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
	}

	// 2. Fetch the schedulable account list
	if !accountsLoaded {
		var err error
		accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
		if err != nil {
			return nil, fmt.Errorf("query accounts failed: %w", err)
		}
	}

	// 3. Select by priority + least recently used (with model support and mixed scheduling considered)
	var selected *Account
@@ -1488,6 +2106,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
}

func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
	// For Anthropic OAuth accounts, use ClaudeTokenProvider to obtain the cached token
	if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil {
		accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
		if err != nil {
			return "", "", err
		}
		return accessToken, "oauth", nil
	}

	// Otherwise (Gemini has its own TokenProvider, setup-token type, etc.), read directly from the account
	accessToken := account.GetCredential("access_token")
	if accessToken == "" {
		return "", "", errors.New("access_token not found in credentials")
@@ -1901,6 +2529,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
	retryStart := time.Now()
	for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
		// Build the upstream request (rebuilt on each retry, because the request body must be re-read)
		// Capture upstream request body for ops retry of this attempt.
		c.Set(OpsUpstreamRequestBodyKey, string(body))
		upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
		if err != nil {
			return nil, err
@@ -1918,6 +2548,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: 0,
				Kind:               "request_error",
				Message:            safeErr,
@@ -1942,6 +2573,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "signature_error",
@@ -1993,6 +2625,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: retryResp.StatusCode,
				UpstreamRequestID:  retryResp.Header.Get("x-request-id"),
				Kind:               "signature_retry_thinking",
@@ -2021,6 +2654,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: 0,
				Kind:               "signature_retry_tools_request_error",
				Message:            sanitizeUpstreamErrorMessage(retryErr2.Error()),
@@ -2079,6 +2713,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "retry",
@@ -2127,6 +2762,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "retry_exhausted_failover",
@@ -2193,6 +2829,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover_on_400",
@@ -3283,6 +3920,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
	if result.ImageSize != "" {
		imageSize = &result.ImageSize
	}
	accountRateMultiplier := account.BillingRateMultiplier()
	usageLog := &UsageLog{
		UserID:   user.ID,
		APIKeyID: apiKey.ID,
@@ -3300,6 +3938,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
		TotalCost:             cost.TotalCost,
		ActualCost:            cost.ActualCost,
		RateMultiplier:        multiplier,
		AccountRateMultiplier: &accountRateMultiplier,
		BillingType:           billingType,
		Stream:                result.Stream,
		DurationMs:            &durationMs,
@@ -545,12 +545,19 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
		}
		requestIDHeader = idHeader

		// Capture upstream request body for ops retry of this attempt.
		if c != nil {
			// In this code path `body` is already the JSON sent to upstream.
			c.Set(OpsUpstreamRequestBodyKey, string(body))
		}

		resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
		if err != nil {
			safeErr := sanitizeUpstreamErrorMessage(err.Error())
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: 0,
				Kind:               "request_error",
				Message:            safeErr,
@@ -588,6 +595,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  upstreamReqID,
				Kind:               "signature_error",
@@ -662,6 +670,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  upstreamReqID,
				Kind:               "retry",
@@ -711,6 +720,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  upstreamReqID,
				Kind:               "failover",
@@ -737,6 +747,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  upstreamReqID,
				Kind:               "failover",
@@ -972,12 +983,19 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
		}
		requestIDHeader = idHeader

		// Capture upstream request body for ops retry of this attempt.
		if c != nil {
			// In this code path `body` is already the JSON sent to upstream.
			c.Set(OpsUpstreamRequestBodyKey, string(body))
		}

		resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
		if err != nil {
			safeErr := sanitizeUpstreamErrorMessage(err.Error())
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: 0,
				Kind:               "request_error",
				Message:            safeErr,
@@ -1036,6 +1054,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  upstreamReqID,
				Kind:               "retry",
@@ -1120,6 +1139,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  requestID,
				Kind:               "failover",
@@ -1143,6 +1163,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  requestID,
				Kind:               "failover",
@@ -1168,6 +1189,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  requestID,
				Kind:               "http_error",
@@ -1300,6 +1322,7 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: upstreamStatus,
|
UpstreamStatusCode: upstreamStatus,
|
||||||
UpstreamRequestID: upstreamRequestID,
|
UpstreamRequestID: upstreamRequestID,
|
||||||
Kind: "http_error",
|
Kind: "http_error",
|
||||||
|
|||||||
@@ -125,6 +125,9 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
 func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
 	return nil
 }
+func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
+	return nil
+}
 func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
 	return nil
 }
@@ -138,6 +141,9 @@ func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64)
 func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
 	return nil
 }
+func (m *mockAccountRepoForGemini) ClearModelRateLimits(ctx context.Context, id int64) error {
+	return nil
+}
 func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
 	return nil
 }
@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
 	// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
 	GetAccessToken(ctx context.Context, cacheKey string) (string, error)
 	SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
+	DeleteAccessToken(ctx context.Context, cacheKey string) error
 
 	AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
 	ReleaseRefreshLock(ctx context.Context, cacheKey string) error
@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
 		return "", errors.New("not a gemini oauth account")
 	}
 
-	cacheKey := geminiTokenCacheKey(account)
+	cacheKey := GeminiTokenCacheKey(account)
 
 	// 1) Try cache first.
 	if p.tokenCache != nil {
@@ -151,10 +151,10 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
 	return accessToken, nil
 }
 
-func geminiTokenCacheKey(account *Account) string {
+func GeminiTokenCacheKey(account *Account) string {
 	projectID := strings.TrimSpace(account.GetCredential("project_id"))
 	if projectID != "" {
-		return projectID
+		return "gemini:" + projectID
 	}
-	return "account:" + strconv.FormatInt(account.ID, 10)
+	return "gemini:account:" + strconv.FormatInt(account.ID, 10)
 }
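The exported, renamed `GeminiTokenCacheKey` also namespaces keys with a `gemini:` prefix, so providers that share one cache backend cannot collide on a bare project ID or account ID. A minimal standalone sketch of the two key shapes (a reimplementation for illustration, not the project's actual types):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// geminiKey mirrors the diff's GeminiTokenCacheKey logic: prefer the
// OAuth project_id, fall back to the numeric account ID, and prefix
// both with "gemini:" so keys from other platforms cannot collide.
func geminiKey(projectID string, accountID int64) string {
	if p := strings.TrimSpace(projectID); p != "" {
		return "gemini:" + p
	}
	return "gemini:account:" + strconv.FormatInt(accountID, 10)
}

func main() {
	fmt.Println(geminiKey("my-project", 42)) // gemini:my-project
	fmt.Println(geminiKey("   ", 42))        // gemini:account:42
}
```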
@@ -1,6 +1,9 @@
 package service
 
-import "time"
+import (
+	"strings"
+	"time"
+)
 
 type Group struct {
 	ID int64
@@ -27,6 +30,12 @@ type Group struct {
 	ClaudeCodeOnly  bool
 	FallbackGroupID *int64
 
+	// Model routing configuration
+	// key: model match pattern (supports the * wildcard, e.g. "claude-opus-*")
+	// value: list of preferred account IDs
+	ModelRouting        map[string][]int64
+	ModelRoutingEnabled bool
+
 	CreatedAt time.Time
 	UpdatedAt time.Time
 
@@ -90,3 +99,41 @@ func IsGroupContextValid(group *Group) bool {
 	}
 	return true
 }
+
+// GetRoutingAccountIDs returns the routed account IDs for the requested model.
+// It returns the matched list of preferred account IDs, or nil if no rule matches.
+func (g *Group) GetRoutingAccountIDs(requestedModel string) []int64 {
+	if !g.ModelRoutingEnabled || len(g.ModelRouting) == 0 || requestedModel == "" {
+		return nil
+	}
+
+	// 1. Exact match takes priority
+	if accountIDs, ok := g.ModelRouting[requestedModel]; ok && len(accountIDs) > 0 {
+		return accountIDs
+	}
+
+	// 2. Wildcard match (prefix match)
+	for pattern, accountIDs := range g.ModelRouting {
+		if matchModelPattern(pattern, requestedModel) && len(accountIDs) > 0 {
+			return accountIDs
+		}
+	}
+
+	return nil
+}
+
+// matchModelPattern reports whether the model matches the pattern.
+// Supports the * wildcard, e.g. "claude-opus-*" matches "claude-opus-4-20250514".
+func matchModelPattern(pattern, model string) bool {
+	if pattern == model {
+		return true
+	}
+
+	// Handle the * wildcard (trailing wildcard only)
+	if strings.HasSuffix(pattern, "*") {
+		prefix := strings.TrimSuffix(pattern, "*")
+		return strings.HasPrefix(model, prefix)
+	}
+
+	return false
+}
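The routing lookup above resolves an exact model name before iterating the wildcard patterns. A standalone sketch of the same lookup order, with hypothetical rule data (not taken from the project):

```go
package main

import (
	"fmt"
	"strings"
)

// routeModel resolves a model name to preferred account IDs: exact
// match first, then trailing-* prefix patterns, mirroring the diff's
// GetRoutingAccountIDs/matchModelPattern pair.
func routeModel(routing map[string][]int64, model string) []int64 {
	if ids, ok := routing[model]; ok && len(ids) > 0 {
		return ids
	}
	for pattern, ids := range routing {
		if strings.HasSuffix(pattern, "*") &&
			strings.HasPrefix(model, strings.TrimSuffix(pattern, "*")) && len(ids) > 0 {
			return ids
		}
	}
	return nil
}

func main() {
	rules := map[string][]int64{
		"claude-opus-4-20250514": {1},
		"claude-opus-*":          {2, 3},
	}
	fmt.Println(routeModel(rules, "claude-opus-4-20250514")) // [1] (exact wins over wildcard)
	fmt.Println(routeModel(rules, "claude-opus-5"))          // [2 3]
	fmt.Println(routeModel(rules, "gpt-4o"))                 // []
}
```

Note that map iteration order is unspecified in Go, so if two wildcard patterns matched the same model, the chosen one would be nondeterministic; the diff's implementation has the same property.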
56	backend/internal/service/model_rate_limit.go	Normal file
@@ -0,0 +1,56 @@
+package service
+
+import (
+	"strings"
+	"time"
+)
+
+const modelRateLimitsKey = "model_rate_limits"
+const modelRateLimitScopeClaudeSonnet = "claude_sonnet"
+
+func resolveModelRateLimitScope(requestedModel string) (string, bool) {
+	model := strings.ToLower(strings.TrimSpace(requestedModel))
+	if model == "" {
+		return "", false
+	}
+	model = strings.TrimPrefix(model, "models/")
+	if strings.Contains(model, "sonnet") {
+		return modelRateLimitScopeClaudeSonnet, true
+	}
+	return "", false
+}
+
+func (a *Account) isModelRateLimited(requestedModel string) bool {
+	scope, ok := resolveModelRateLimitScope(requestedModel)
+	if !ok {
+		return false
+	}
+	resetAt := a.modelRateLimitResetAt(scope)
+	if resetAt == nil {
+		return false
+	}
+	return time.Now().Before(*resetAt)
+}
+
+func (a *Account) modelRateLimitResetAt(scope string) *time.Time {
+	if a == nil || a.Extra == nil || scope == "" {
+		return nil
+	}
+	rawLimits, ok := a.Extra[modelRateLimitsKey].(map[string]any)
+	if !ok {
+		return nil
+	}
+	rawLimit, ok := rawLimits[scope].(map[string]any)
+	if !ok {
+		return nil
+	}
+	resetAtRaw, ok := rawLimit["rate_limit_reset_at"].(string)
+	if !ok || strings.TrimSpace(resetAtRaw) == "" {
+		return nil
+	}
+	resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
+	if err != nil {
+		return nil
+	}
+	return &resetAt
+}
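The new `modelRateLimitResetAt` walks a nested `map[string]any` stored in `Account.Extra` and parses an RFC 3339 timestamp, returning nil at the first missing or malformed level. A standalone sketch of the same defensive extraction (simplified to a bare map rather than the project's `Account` type):

```go
package main

import (
	"fmt"
	"time"
)

// resetAt extracts extra["model_rate_limits"][scope]["rate_limit_reset_at"]
// as a *time.Time, bailing out with nil on any missing key, wrong type,
// or unparseable timestamp, mirroring the type assertions in the diff.
func resetAt(extra map[string]any, scope string) *time.Time {
	limits, ok := extra["model_rate_limits"].(map[string]any)
	if !ok {
		return nil
	}
	limit, ok := limits[scope].(map[string]any)
	if !ok {
		return nil
	}
	raw, ok := limit["rate_limit_reset_at"].(string)
	if !ok {
		return nil
	}
	t, err := time.Parse(time.RFC3339, raw)
	if err != nil {
		return nil
	}
	return &t
}

func main() {
	extra := map[string]any{
		"model_rate_limits": map[string]any{
			"claude_sonnet": map[string]any{"rate_limit_reset_at": "2025-01-01T00:00:00Z"},
		},
	}
	fmt.Println(resetAt(extra, "claude_sonnet") != nil) // true
	fmt.Println(resetAt(extra, "claude_opus") == nil)   // true
}
```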
@@ -93,6 +93,8 @@ type OpenAIGatewayService struct {
 	billingCacheService *BillingCacheService
 	httpUpstream        HTTPUpstream
 	deferredService     *DeferredService
+	openAITokenProvider *OpenAITokenProvider
+	toolCorrector       *CodexToolCorrector
 }
 
 // NewOpenAIGatewayService creates a new OpenAIGatewayService
@@ -110,6 +112,7 @@ func NewOpenAIGatewayService(
 	billingCacheService *BillingCacheService,
 	httpUpstream HTTPUpstream,
 	deferredService *DeferredService,
+	openAITokenProvider *OpenAITokenProvider,
 ) *OpenAIGatewayService {
 	return &OpenAIGatewayService{
 		accountRepo: accountRepo,
@@ -125,6 +128,8 @@ func NewOpenAIGatewayService(
 		billingCacheService: billingCacheService,
 		httpUpstream:        httpUpstream,
 		deferredService:     deferredService,
+		openAITokenProvider: openAITokenProvider,
+		toolCorrector:       NewCodexToolCorrector(),
 	}
 }
 
@@ -503,6 +508,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
 func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
 	switch account.Type {
 	case AccountTypeOAuth:
+		// Use the TokenProvider to obtain the cached token
+		if s.openAITokenProvider != nil {
+			accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
+			if err != nil {
+				return "", "", err
+			}
+			return accessToken, "oauth", nil
+		}
+		// Fallback: read directly from the account when no TokenProvider is configured
 		accessToken := account.GetOpenAIAccessToken()
 		if accessToken == "" {
 			return "", "", errors.New("access_token not found in credentials")
@@ -664,6 +678,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
 		proxyURL = account.Proxy.URL()
 	}
 
+	// Capture upstream request body for ops retry of this attempt.
+	if c != nil {
+		c.Set(OpsUpstreamRequestBodyKey, string(body))
+	}
+
 	// Send request
 	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
 	if err != nil {
@@ -673,6 +692,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
 		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 			Platform:           account.Platform,
 			AccountID:          account.ID,
+			AccountName:        account.Name,
 			UpstreamStatusCode: 0,
 			Kind:               "request_error",
 			Message:            safeErr,
@@ -707,6 +727,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
 		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 			Platform:           account.Platform,
 			AccountID:          account.ID,
+			AccountName:        account.Name,
 			UpstreamStatusCode: resp.StatusCode,
 			UpstreamRequestID:  resp.Header.Get("x-request-id"),
 			Kind:               "failover",
@@ -864,6 +885,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
 		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 			Platform:           account.Platform,
 			AccountID:          account.ID,
+			AccountName:        account.Name,
 			UpstreamStatusCode: resp.StatusCode,
 			UpstreamRequestID:  resp.Header.Get("x-request-id"),
 			Kind:               "http_error",
@@ -894,6 +916,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
 		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
 			Platform:           account.Platform,
 			AccountID:          account.ID,
+			AccountName:        account.Name,
 			UpstreamStatusCode: resp.StatusCode,
 			UpstreamRequestID:  resp.Header.Get("x-request-id"),
 			Kind:               kind,
@@ -1097,6 +1120,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 				line = s.replaceModelInSSELine(line, mappedModel, originalModel)
 			}
 
+			// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
+			if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
+				data = correctedData
+				line = "data: " + correctedData
+			}
+
 			// Write to the client (keep draining upstream after the client disconnects)
 			if !clientDisconnected {
 				if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
@@ -1199,6 +1228,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
 	return line
 }
 
+// correctToolCallsInResponseBody corrects tool calls in the response body
+func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte {
+	if len(body) == 0 {
+		return body
+	}
+
+	bodyStr := string(body)
+	corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr)
+	if changed {
+		return []byte(corrected)
+	}
+	return body
+}
+
 func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
 	// Parse response.completed event for usage (OpenAI Responses format)
 	var event struct {
@@ -1302,6 +1345,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
 		if originalModel != mappedModel {
 			body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
 		}
+		// Correct tool calls in final response
+		body = s.correctToolCallsInResponseBody(body)
 	} else {
 		usage = s.parseSSEUsageFromBody(bodyText)
 		if originalModel != mappedModel {
@@ -1470,6 +1515,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
 
 	// Create usage log
 	durationMs := int(result.Duration.Milliseconds())
+	accountRateMultiplier := account.BillingRateMultiplier()
 	usageLog := &UsageLog{
 		UserID:   user.ID,
 		APIKeyID: apiKey.ID,
@@ -1487,6 +1533,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
 		TotalCost:      cost.TotalCost,
 		ActualCost:     cost.ActualCost,
 		RateMultiplier: multiplier,
+		AccountRateMultiplier: &accountRateMultiplier,
 		BillingType:    billingType,
 		Stream:         result.Stream,
 		DurationMs:     &durationMs,
|||||||
@@ -0,0 +1,133 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestOpenAIGatewayService_ToolCorrection 测试 OpenAIGatewayService 中的工具修正集成
|
||||||
|
func TestOpenAIGatewayService_ToolCorrection(t *testing.T) {
|
||||||
|
// 创建一个简单的 service 实例来测试工具修正
|
||||||
|
service := &OpenAIGatewayService{
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
expected string
|
||||||
|
changed bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "correct apply_patch in response body",
|
||||||
|
input: []byte(`{
|
||||||
|
"choices": [{
|
||||||
|
"message": {
|
||||||
|
"tool_calls": [{
|
||||||
|
"function": {"name": "apply_patch"}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}`),
|
||||||
|
expected: "edit",
|
||||||
|
changed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "correct update_plan in response body",
|
||||||
|
input: []byte(`{
|
||||||
|
"tool_calls": [{
|
||||||
|
"function": {"name": "update_plan"}
|
||||||
|
}]
|
||||||
|
}`),
|
||||||
|
expected: "todowrite",
|
||||||
|
changed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no change for correct tool name",
|
||||||
|
input: []byte(`{
|
||||||
|
"tool_calls": [{
|
||||||
|
"function": {"name": "edit"}
|
||||||
|
}]
|
||||||
|
}`),
|
||||||
|
expected: "edit",
|
||||||
|
changed: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := service.correctToolCallsInResponseBody(tt.input)
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
// 检查是否包含期望的工具名称
|
||||||
|
if !strings.Contains(resultStr, tt.expected) {
|
||||||
|
t.Errorf("expected result to contain %q, got %q", tt.expected, resultStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对于预期有变化的情况,验证结果与输入不同
|
||||||
|
if tt.changed && string(result) == string(tt.input) {
|
||||||
|
t.Error("expected result to be different from input, but they are the same")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对于预期无变化的情况,验证结果与输入相同
|
||||||
|
if !tt.changed && string(result) != string(tt.input) {
|
||||||
|
t.Error("expected result to be same as input, but they are different")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOpenAIGatewayService_ToolCorrectorInitialization 测试工具修正器是否正确初始化
|
||||||
|
func TestOpenAIGatewayService_ToolCorrectorInitialization(t *testing.T) {
|
||||||
|
service := &OpenAIGatewayService{
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if service.toolCorrector == nil {
|
||||||
|
t.Fatal("toolCorrector should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试修正器可以正常工作
|
||||||
|
data := `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
|
||||||
|
corrected, changed := service.toolCorrector.CorrectToolCallsInSSEData(data)
|
||||||
|
|
||||||
|
if !changed {
|
||||||
|
t.Error("expected tool call to be corrected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(corrected, "edit") {
|
||||||
|
t.Errorf("expected corrected data to contain 'edit', got %q", corrected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolCorrectionStats 测试工具修正统计功能
|
||||||
|
func TestToolCorrectionStats(t *testing.T) {
|
||||||
|
service := &OpenAIGatewayService{
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行几次修正
|
||||||
|
testData := []string{
|
||||||
|
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
|
||||||
|
`{"tool_calls":[{"function":{"name":"update_plan"}}]}`,
|
||||||
|
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, data := range testData {
|
||||||
|
service.toolCorrector.CorrectToolCallsInSSEData(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := service.toolCorrector.GetStats()
|
||||||
|
|
||||||
|
if stats.TotalCorrected != 3 {
|
||||||
|
t.Errorf("expected 3 corrections, got %d", stats.TotalCorrected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
|
||||||
|
t.Errorf("expected 2 apply_patch->edit corrections, got %d", stats.CorrectionsByTool["apply_patch->edit"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
|
||||||
|
t.Errorf("expected 1 update_plan->todowrite correction, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
|
||||||
|
}
|
||||||
|
}
|
||||||
189	backend/internal/service/openai_token_provider.go	Normal file
@@ -0,0 +1,189 @@
+package service
+
+import (
+	"context"
+	"errors"
+	"log/slog"
+	"strings"
+	"time"
+)
+
+const (
+	openAITokenRefreshSkew = 3 * time.Minute
+	openAITokenCacheSkew   = 5 * time.Minute
+	openAILockWaitTime     = 200 * time.Millisecond
+)
+
+// OpenAITokenCache is the token cache interface (reuses the GeminiTokenCache definition)
+type OpenAITokenCache = GeminiTokenCache
+
+// OpenAITokenProvider manages access tokens for OpenAI OAuth accounts
+type OpenAITokenProvider struct {
+	accountRepo        AccountRepository
+	tokenCache         OpenAITokenCache
+	openAIOAuthService *OpenAIOAuthService
+}
+
+func NewOpenAITokenProvider(
+	accountRepo AccountRepository,
+	tokenCache OpenAITokenCache,
+	openAIOAuthService *OpenAIOAuthService,
+) *OpenAITokenProvider {
+	return &OpenAITokenProvider{
+		accountRepo:        accountRepo,
+		tokenCache:         tokenCache,
+		openAIOAuthService: openAIOAuthService,
+	}
+}
+
+// GetAccessToken returns a valid access_token
+func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
+	if account == nil {
+		return "", errors.New("account is nil")
+	}
+	if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
+		return "", errors.New("not an openai oauth account")
+	}
+
+	cacheKey := OpenAITokenCacheKey(account)
+
+	// 1. Try the cache first
+	if p.tokenCache != nil {
+		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+			slog.Debug("openai_token_cache_hit", "account_id", account.ID)
+			return token, nil
+		} else if err != nil {
+			slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err)
+		}
+	}
+
+	slog.Debug("openai_token_cache_miss", "account_id", account.ID)
+
+	// 2. Refresh if the token is about to expire
+	expiresAt := account.GetCredentialAsTime("expires_at")
+	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
+	refreshFailed := false
+	if needsRefresh && p.tokenCache != nil {
+		locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
+		if lockErr == nil && locked {
+			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
+
+			// Re-check the cache after acquiring the lock (another worker may have refreshed already)
+			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+				return token, nil
+			}
+
+			// Load the latest account state from the database
+			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
+			if err == nil && fresh != nil {
+				account = fresh
+			}
+			expiresAt = account.GetCredentialAsTime("expires_at")
+			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
+				if p.openAIOAuthService == nil {
+					slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
+					refreshFailed = true // cannot refresh; mark as failed
+				} else {
+					tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
+					if err != nil {
+						// Log a warning on refresh failure, but do not return immediately; try the existing token
+						slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
+						refreshFailed = true // refresh failed; mark so a short TTL is used
+					} else {
+						newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
+						for k, v := range account.Credentials {
+							if _, exists := newCredentials[k]; !exists {
+								newCredentials[k] = v
+							}
+						}
+						account.Credentials = newCredentials
+						if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
+							slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
+						}
+						expiresAt = account.GetCredentialAsTime("expires_at")
+					}
+				}
+			}
+		} else if lockErr != nil {
+			// The lock could not be acquired due to a Redis error; degrade to a lock-free refresh (only when the token is near expiry)
+			slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
+
+			// Check whether ctx has been cancelled
+			if ctx.Err() != nil {
+				return "", ctx.Err()
+			}
+
+			// Load the latest account state from the database
+			if p.accountRepo != nil {
+				fresh, err := p.accountRepo.GetByID(ctx, account.ID)
+				if err == nil && fresh != nil {
+					account = fresh
+				}
+			}
+			expiresAt = account.GetCredentialAsTime("expires_at")
+
+			// Perform the lock-free refresh only when expires_at is expired or near expiry
+			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
+				if p.openAIOAuthService == nil {
+					slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
+					refreshFailed = true
+				} else {
+					tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
+					if err != nil {
+						slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
+						refreshFailed = true
+					} else {
+						newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
+						for k, v := range account.Credentials {
+							if _, exists := newCredentials[k]; !exists {
+								newCredentials[k] = v
+							}
+						}
+						account.Credentials = newCredentials
+						if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
+							slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
+						}
+						expiresAt = account.GetCredentialAsTime("expires_at")
+					}
+				}
+			}
+		} else {
+			// The lock is held by another worker; wait 200ms and then re-read the cache
+			time.Sleep(openAILockWaitTime)
+			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+				slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
+				return token, nil
+			}
+		}
+	}
+
+	accessToken := account.GetOpenAIAccessToken()
+	if strings.TrimSpace(accessToken) == "" {
+		return "", errors.New("access_token not found in credentials")
+	}
+
+	// 3. Store the token in the cache
+	if p.tokenCache != nil {
+		ttl := 30 * time.Minute
+		if refreshFailed {
+			// Use a short TTL after a failed refresh so a stale token is not cached long enough to cause 401 churn
+			ttl = time.Minute
+			slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
+		} else if expiresAt != nil {
+			until := time.Until(*expiresAt)
+			switch {
+			case until > openAITokenCacheSkew:
+				ttl = until - openAITokenCacheSkew
+			case until > 0:
+				ttl = until
			default:
+				ttl = time.Minute
+			}
+		}
+		if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
+			slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
+		}
+	}
+
+	return accessToken, nil
+}
810
backend/internal/service/openai_token_provider_test.go
Normal file
810
backend/internal/service/openai_token_provider_test.go
Normal file
@@ -0,0 +1,810 @@
//go:build unit

package service

import (
	"context"
	"errors"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

// openAITokenCacheStub implements OpenAITokenCache for testing
type openAITokenCacheStub struct {
	mu               sync.Mutex
	tokens           map[string]string
	getErr           error
	setErr           error
	deleteErr        error
	lockAcquired     bool
	lockErr          error
	releaseLockErr   error
	getCalled        int32
	setCalled        int32
	lockCalled       int32
	unlockCalled     int32
	simulateLockRace bool
}

func newOpenAITokenCacheStub() *openAITokenCacheStub {
	return &openAITokenCacheStub{
		tokens:       make(map[string]string),
		lockAcquired: true,
	}
}

func (s *openAITokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
	atomic.AddInt32(&s.getCalled, 1)
	if s.getErr != nil {
		return "", s.getErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.tokens[cacheKey], nil
}

func (s *openAITokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
	atomic.AddInt32(&s.setCalled, 1)
	if s.setErr != nil {
		return s.setErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tokens[cacheKey] = token
	return nil
}

func (s *openAITokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
	if s.deleteErr != nil {
		return s.deleteErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.tokens, cacheKey)
	return nil
}

func (s *openAITokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
	atomic.AddInt32(&s.lockCalled, 1)
	if s.lockErr != nil {
		return false, s.lockErr
	}
	if s.simulateLockRace {
		return false, nil
	}
	return s.lockAcquired, nil
}

func (s *openAITokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
	atomic.AddInt32(&s.unlockCalled, 1)
	return s.releaseLockErr
}

// openAIAccountRepoStub is a minimal stub implementing only the methods used by OpenAITokenProvider
type openAIAccountRepoStub struct {
	account      *Account
	getErr       error
	updateErr    error
	getCalled    int32
	updateCalled int32
}

func (r *openAIAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
	atomic.AddInt32(&r.getCalled, 1)
	if r.getErr != nil {
		return nil, r.getErr
	}
	return r.account, nil
}

func (r *openAIAccountRepoStub) Update(ctx context.Context, account *Account) error {
	atomic.AddInt32(&r.updateCalled, 1)
	if r.updateErr != nil {
		return r.updateErr
	}
	r.account = account
	return nil
}

// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing
type openAIOAuthServiceStub struct {
	tokenInfo     *OpenAITokenInfo
	refreshErr    error
	refreshCalled int32
}

func (s *openAIOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
	atomic.AddInt32(&s.refreshCalled, 1)
	if s.refreshErr != nil {
		return nil, s.refreshErr
	}
	return s.tokenInfo, nil
}

func (s *openAIOAuthServiceStub) BuildAccountCredentials(info *OpenAITokenInfo) map[string]any {
	now := time.Now()
	return map[string]any{
		"access_token":  info.AccessToken,
		"refresh_token": info.RefreshToken,
		"expires_at":    now.Add(time.Duration(info.ExpiresIn) * time.Second).Format(time.RFC3339),
	}
}

func TestOpenAITokenProvider_CacheHit(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	account := &Account{
		ID:       100,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "db-token",
		},
	}
	cacheKey := OpenAITokenCacheKey(account)
	cache.tokens[cacheKey] = "cached-token"

	provider := NewOpenAITokenProvider(nil, cache, nil)

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "cached-token", token)
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}

func TestOpenAITokenProvider_CacheMiss_FromCredentials(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	// Token expires in far future, no refresh needed
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       101,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "credential-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "credential-token", token)

	// Should have stored in cache
	cacheKey := OpenAITokenCacheKey(account)
	require.Equal(t, "credential-token", cache.tokens[cacheKey])
}

func TestOpenAITokenProvider_TokenRefresh(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		tokenInfo: &OpenAITokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh-token",
			ExpiresIn:    3600,
		},
	}

	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       102,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account

	// We need to directly test with the stub - create a custom provider
	customProvider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}

	token, err := customProvider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
	require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}

// testOpenAITokenProvider is a test version that uses the stub OAuth service
type testOpenAITokenProvider struct {
	accountRepo  *openAIAccountRepoStub
	tokenCache   *openAITokenCacheStub
	oauthService *openAIOAuthServiceStub
}

func (p *testOpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
		return "", errors.New("not an openai oauth account")
	}

	cacheKey := OpenAITokenCacheKey(account)

	// 1. Check cache
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
			return token, nil
		}
	}

	// 2. Check if refresh needed
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
	refreshFailed := false
	if needsRefresh && p.tokenCache != nil {
		locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if err == nil && locked {
			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()

			// Check cache again after acquiring lock
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}

			// Get fresh account from DB
			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
			if err == nil && fresh != nil {
				account = fresh
			}
			expiresAt = account.GetCredentialAsTime("expires_at")
			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
				if p.oauthService == nil {
					refreshFailed = true // cannot refresh, mark as failed
				} else {
					tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
					if err != nil {
						refreshFailed = true // refresh failed, mark so a short TTL is used
					} else {
						newCredentials := p.oauthService.BuildAccountCredentials(tokenInfo)
						for k, v := range account.Credentials {
							if _, exists := newCredentials[k]; !exists {
								newCredentials[k] = v
							}
						}
						account.Credentials = newCredentials
						_ = p.accountRepo.Update(ctx, account)
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else if p.tokenCache.simulateLockRace {
			// Wait and retry cache
			time.Sleep(10 * time.Millisecond) // Short wait for test
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
		}
	}

	accessToken := account.GetOpenAIAccessToken()
	if accessToken == "" {
		return "", errors.New("access_token not found in credentials")
	}

	// 3. Store in cache
	if p.tokenCache != nil {
		ttl := 30 * time.Minute
		if refreshFailed {
			ttl = time.Minute // use a short TTL when refresh fails
		} else if expiresAt != nil {
			until := time.Until(*expiresAt)
			if until > openAITokenCacheSkew {
				ttl = until - openAITokenCacheSkew
			} else if until > 0 {
				ttl = until
			} else {
				ttl = time.Minute
			}
		}
		_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
	}

	return accessToken, nil
}

func TestOpenAITokenProvider_LockRaceCondition(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.simulateLockRace = true
	accountRepo := &openAIAccountRepoStub{}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       103,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "race-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account

	// Simulate another worker already refreshed and cached
	cacheKey := OpenAITokenCacheKey(account)
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()

	provider := &testOpenAITokenProvider{
		accountRepo: accountRepo,
		tokenCache:  cache,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get the token set by the "winner" or the original
	require.NotEmpty(t, token)
}

func TestOpenAITokenProvider_NilAccount(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)

	token, err := provider.GetAccessToken(context.Background(), nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), "account is nil")
	require.Empty(t, token)
}

func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	account := &Account{
		ID:       104,
		Platform: PlatformGemini,
		Type:     AccountTypeOAuth,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an openai oauth account")
	require.Empty(t, token)
}

func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	account := &Account{
		ID:       105,
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an openai oauth account")
	require.Empty(t, token)
}

func TestOpenAITokenProvider_NilCache(t *testing.T) {
	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       106,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "nocache-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, nil, nil)

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "nocache-token", token)
}

func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.getErr = errors.New("redis connection failed")

	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       107,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)

	// Should gracefully degrade and return from credentials
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-token", token)
}

func TestOpenAITokenProvider_CacheSetError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.setErr = errors.New("redis write failed")

	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       108,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "still-works-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)

	// Should still work even if cache set fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "still-works-token", token)
}

func TestOpenAITokenProvider_MissingAccessToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       109,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"expires_at": expiresAt,
			// missing access_token
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)

	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}

func TestOpenAITokenProvider_RefreshError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		refreshErr: errors.New("oauth refresh failed"),
	}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       110,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account

	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}

	// Now with fallback behavior, should return existing token even if refresh fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}

func TestOpenAITokenProvider_OAuthServiceNotConfigured(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       111,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account

	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: nil, // not configured
	}

	// Now with fallback behavior, should return existing token even if oauth service not configured
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}

func TestOpenAITokenProvider_TTLCalculation(t *testing.T) {
	tests := []struct {
		name      string
		expiresIn time.Duration
	}{
		{
			name:      "far_future_expiry",
			expiresIn: 1 * time.Hour,
		},
		{
			name:      "medium_expiry",
			expiresIn: 10 * time.Minute,
		},
		{
			name:      "near_expiry",
			expiresIn: 6 * time.Minute,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cache := newOpenAITokenCacheStub()
			expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
			account := &Account{
				ID:       200,
				Platform: PlatformOpenAI,
				Type:     AccountTypeOAuth,
				Credentials: map[string]any{
					"access_token": "test-token",
					"expires_at":   expiresAt,
				},
			}

			provider := NewOpenAITokenProvider(nil, cache, nil)

			_, err := provider.GetAccessToken(context.Background(), account)
			require.NoError(t, err)

			// Verify token was cached
			cacheKey := OpenAITokenCacheKey(account)
			require.Equal(t, "test-token", cache.tokens[cacheKey])
		})
	}
}

func TestOpenAITokenProvider_DoubleCheckAfterLock(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		tokenInfo: &OpenAITokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh",
			ExpiresIn:    3600,
		},
	}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       112,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	cacheKey := OpenAITokenCacheKey(account)

	// Simulate: first GetAccessToken returns empty, but after lock acquired, cache has token
	originalGet := int32(0)
	cache.tokens[cacheKey] = "" // Empty initially

	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}

	// In a goroutine, set the cached token after a small delay (simulating race)
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "cached-by-other"
		cache.mu.Unlock()
	}()

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get either the refreshed token or the cached one
	require.NotEmpty(t, token)
	_ = originalGet // Suppress unused warning
}

// Tests for real provider - to increase coverage
func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails

	// Token expires soon (within refresh skew) to trigger lock attempt
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       200,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}

	// Set token in cache after lock wait period (simulate other worker refreshing)
	cacheKey := OpenAITokenCacheKey(account)
	go func() {
		time.Sleep(100 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "refreshed-by-other"
		cache.mu.Unlock()
	}()

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get either the fallback token or the refreshed one
	require.NotEmpty(t, token)
}

func TestOpenAITokenProvider_Real_CacheHitAfterWait(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       201,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "original-token",
			"expires_at":   expiresAt,
		},
	}

	cacheKey := OpenAITokenCacheKey(account)
	// Set token in cache immediately after wait starts
	go func() {
		time.Sleep(50 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}

func TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Prevent entering refresh logic

	// Token with nil expires_at (no expiry set) - should use credentials
	account := &Account{
		ID:       202,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "no-expiry-token",
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	// Without OAuth service, refresh will fail but token should be returned from credentials
	require.NoError(t, err)
	require.Equal(t, "no-expiry-token", token)
}

func TestOpenAITokenProvider_Real_WhitespaceToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cacheKey := "openai:account:203"
	cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty

	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       203,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "real-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "real-token", token) // Should fall back to credentials
}

func TestOpenAITokenProvider_Real_LockError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockErr = errors.New("redis lock failed")

	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       204,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-on-lock-error",
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-on-lock-error", token)
}

func TestOpenAITokenProvider_Real_WhitespaceCredentialToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()

	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       205,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": " ", // Whitespace only
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}

func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
	cache := newOpenAITokenCacheStub()

	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       206,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"expires_at": expiresAt,
			// No access_token
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}
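The lock-race tests above all rely on the same pattern: a stub cache denies the refresh lock, and a background goroutine plants the token that "another worker" would have written, so the caller's wait-and-recheck branch picks it up. A minimal self-contained sketch of that pattern follows; the `stubCache` and `fetchToken` names are illustrative, not from this repo, and the timings are shortened for demonstration.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// stubCache is a tiny in-memory stand-in for the Redis-backed token cache.
type stubCache struct {
	mu     sync.Mutex
	tokens map[string]string
}

func (s *stubCache) get(key string) string {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.tokens[key]
}

func (s *stubCache) set(key, tok string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tokens[key] = tok
}

// fetchToken models the provider's "lock acquisition failed" branch:
// wait briefly, re-read the cache, and fall back to the stale token
// from credentials only if no other worker has refreshed it.
func fetchToken(c *stubCache, key, fallback string) string {
	time.Sleep(20 * time.Millisecond) // simulated lock-wait
	if tok := c.get(key); tok != "" {
		return tok
	}
	return fallback
}

func main() {
	c := &stubCache{tokens: map[string]string{}}
	go func() {
		time.Sleep(5 * time.Millisecond)
		c.set("acct:1", "winner-token") // the "other worker" refreshed first
	}()
	fmt.Println(fetchToken(c, "acct:1", "fallback-token"))
}
```

Because the goroutine writes well before the 20 ms wait elapses, the caller observes the other worker's token; with no writer it would return the fallback, mirroring the `require.NotEmpty` assertions in the tests, which accept either outcome.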
307	backend/internal/service/openai_tool_corrector.go	Normal file
@@ -0,0 +1,307 @@
package service

import (
	"encoding/json"
	"fmt"
	"log"
	"sync"
)

// codexToolNameMapping maps Codex native tool names to their OpenCode equivalents
var codexToolNameMapping = map[string]string{
	"apply_patch":  "edit",
	"applyPatch":   "edit",
	"update_plan":  "todowrite",
	"updatePlan":   "todowrite",
	"read_plan":    "todoread",
	"readPlan":     "todoread",
	"search_files": "grep",
	"searchFiles":  "grep",
	"list_files":   "glob",
	"listFiles":    "glob",
	"read_file":    "read",
	"readFile":     "read",
	"write_file":   "write",
	"writeFile":    "write",
	"execute_bash": "bash",
	"executeBash":  "bash",
	"exec_bash":    "bash",
	"execBash":     "bash",
}

// ToolCorrectionStats records tool-correction statistics (exported for JSON serialization)
type ToolCorrectionStats struct {
	TotalCorrected    int            `json:"total_corrected"`
	CorrectionsByTool map[string]int `json:"corrections_by_tool"`
}

// CodexToolCorrector performs automatic correction of Codex tool calls
type CodexToolCorrector struct {
	stats ToolCorrectionStats
	mu    sync.RWMutex
}

// NewCodexToolCorrector creates a new tool corrector
func NewCodexToolCorrector() *CodexToolCorrector {
	return &CodexToolCorrector{
		stats: ToolCorrectionStats{
			CorrectionsByTool: make(map[string]int),
		},
	}
}

// CorrectToolCallsInSSEData corrects tool calls in SSE data.
// It returns the corrected data and whether any correction was made.
func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, bool) {
	if data == "" || data == "\n" {
		return data, false
	}

	// Try to parse as JSON
	var payload map[string]any
	if err := json.Unmarshal([]byte(data), &payload); err != nil {
		// Not valid JSON; return the original data unchanged
		return data, false
	}

	corrected := false

	// Handle the tool_calls array
	if toolCalls, ok := payload["tool_calls"].([]any); ok {
		if c.correctToolCallsArray(toolCalls) {
			corrected = true
		}
	}

	// Handle the function_call object
	if functionCall, ok := payload["function_call"].(map[string]any); ok {
		if c.correctFunctionCall(functionCall) {
			corrected = true
		}
	}

	// Handle delta.tool_calls
	if delta, ok := payload["delta"].(map[string]any); ok {
		if toolCalls, ok := delta["tool_calls"].([]any); ok {
			if c.correctToolCallsArray(toolCalls) {
				corrected = true
			}
		}
		if functionCall, ok := delta["function_call"].(map[string]any); ok {
			if c.correctFunctionCall(functionCall) {
				corrected = true
			}
		}
	}

	// Handle choices[].message.tool_calls and choices[].delta.tool_calls
	if choices, ok := payload["choices"].([]any); ok {
		for _, choice := range choices {
			if choiceMap, ok := choice.(map[string]any); ok {
				// Tool calls inside message
				if message, ok := choiceMap["message"].(map[string]any); ok {
					if toolCalls, ok := message["tool_calls"].([]any); ok {
						if c.correctToolCallsArray(toolCalls) {
							corrected = true
						}
					}
					if functionCall, ok := message["function_call"].(map[string]any); ok {
						if c.correctFunctionCall(functionCall) {
							corrected = true
						}
					}
				}
				// Tool calls inside delta
				if delta, ok := choiceMap["delta"].(map[string]any); ok {
					if toolCalls, ok := delta["tool_calls"].([]any); ok {
						if c.correctToolCallsArray(toolCalls) {
							corrected = true
						}
					}
					if functionCall, ok := delta["function_call"].(map[string]any); ok {
						if c.correctFunctionCall(functionCall) {
							corrected = true
						}
					}
				}
			}
		}
	}

	if !corrected {
		return data, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 序列化回 JSON
|
||||||
|
correctedBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err)
|
||||||
|
return data, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(correctedBytes), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// correctToolCallsArray 修正工具调用数组中的工具名称
|
||||||
|
func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
|
||||||
|
corrected := false
|
||||||
|
for _, toolCall := range toolCalls {
|
||||||
|
if toolCallMap, ok := toolCall.(map[string]any); ok {
|
||||||
|
if function, ok := toolCallMap["function"].(map[string]any); ok {
|
||||||
|
if c.correctFunctionCall(function) {
|
||||||
|
corrected = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return corrected
|
||||||
|
}
|
||||||
|
|
||||||
|
// correctFunctionCall 修正单个函数调用的工具名称和参数
|
||||||
|
func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
|
||||||
|
name, ok := functionCall["name"].(string)
|
||||||
|
if !ok || name == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
corrected := false
|
||||||
|
|
||||||
|
// 查找并修正工具名称
|
||||||
|
if correctName, found := codexToolNameMapping[name]; found {
|
||||||
|
functionCall["name"] = correctName
|
||||||
|
c.recordCorrection(name, correctName)
|
||||||
|
corrected = true
|
||||||
|
name = correctName // 使用修正后的名称进行参数修正
|
||||||
|
}
|
||||||
|
|
||||||
|
// 修正工具参数(基于工具名称)
|
||||||
|
if c.correctToolParameters(name, functionCall) {
|
||||||
|
corrected = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return corrected
|
||||||
|
}
|
||||||
|
|
||||||
|
// correctToolParameters 修正工具参数以符合 OpenCode 规范
|
||||||
|
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
|
||||||
|
arguments, ok := functionCall["arguments"]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// arguments 可能是字符串(JSON)或已解析的 map
|
||||||
|
var argsMap map[string]any
|
||||||
|
switch v := arguments.(type) {
|
||||||
|
case string:
|
||||||
|
// 解析 JSON 字符串
|
||||||
|
if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case map[string]any:
|
||||||
|
argsMap = v
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
corrected := false
|
||||||
|
|
||||||
|
// 根据工具名称应用特定的参数修正规则
|
||||||
|
switch toolName {
|
||||||
|
case "bash":
|
||||||
|
// 移除 workdir 参数(OpenCode 不支持)
|
||||||
|
if _, exists := argsMap["workdir"]; exists {
|
||||||
|
delete(argsMap, "workdir")
|
||||||
|
corrected = true
|
||||||
|
log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool")
|
||||||
|
}
|
||||||
|
if _, exists := argsMap["work_dir"]; exists {
|
||||||
|
delete(argsMap, "work_dir")
|
||||||
|
corrected = true
|
||||||
|
log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool")
|
||||||
|
}
|
||||||
|
|
||||||
|
case "edit":
|
||||||
|
// OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称
|
||||||
|
// 这里可以添加参数名称的映射逻辑
|
||||||
|
if _, exists := argsMap["file_path"]; !exists {
|
||||||
|
if path, exists := argsMap["path"]; exists {
|
||||||
|
argsMap["file_path"] = path
|
||||||
|
delete(argsMap, "path")
|
||||||
|
corrected = true
|
||||||
|
log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果修正了参数,需要重新序列化
|
||||||
|
if corrected {
|
||||||
|
if _, wasString := arguments.(string); wasString {
|
||||||
|
// 原本是字符串,序列化回字符串
|
||||||
|
if newArgsJSON, err := json.Marshal(argsMap); err == nil {
|
||||||
|
functionCall["arguments"] = string(newArgsJSON)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 原本是 map,直接赋值
|
||||||
|
functionCall["arguments"] = argsMap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return corrected
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordCorrection 记录一次工具名称修正
|
||||||
|
func (c *CodexToolCorrector) recordCorrection(from, to string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
c.stats.TotalCorrected++
|
||||||
|
key := fmt.Sprintf("%s->%s", from, to)
|
||||||
|
c.stats.CorrectionsByTool[key]++
|
||||||
|
|
||||||
|
log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)",
|
||||||
|
from, to, c.stats.TotalCorrected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStats 获取工具修正统计信息
|
||||||
|
func (c *CodexToolCorrector) GetStats() ToolCorrectionStats {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
// 返回副本以避免并发问题
|
||||||
|
statsCopy := ToolCorrectionStats{
|
||||||
|
TotalCorrected: c.stats.TotalCorrected,
|
||||||
|
CorrectionsByTool: make(map[string]int, len(c.stats.CorrectionsByTool)),
|
||||||
|
}
|
||||||
|
for k, v := range c.stats.CorrectionsByTool {
|
||||||
|
statsCopy.CorrectionsByTool[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return statsCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetStats 重置统计信息
|
||||||
|
func (c *CodexToolCorrector) ResetStats() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
c.stats.TotalCorrected = 0
|
||||||
|
c.stats.CorrectionsByTool = make(map[string]int)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CorrectToolName 直接修正工具名称(用于非 SSE 场景)
|
||||||
|
func CorrectToolName(name string) (string, bool) {
|
||||||
|
if correctName, found := codexToolNameMapping[name]; found {
|
||||||
|
return correctName, true
|
||||||
|
}
|
||||||
|
return name, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetToolNameMapping 获取工具名称映射表
|
||||||
|
func GetToolNameMapping() map[string]string {
|
||||||
|
// 返回副本以避免外部修改
|
||||||
|
mapping := make(map[string]string, len(codexToolNameMapping))
|
||||||
|
for k, v := range codexToolNameMapping {
|
||||||
|
mapping[k] = v
|
||||||
|
}
|
||||||
|
return mapping
|
||||||
|
}
|
503	backend/internal/service/openai_tool_corrector_test.go	Normal file
@@ -0,0 +1,503 @@
package service

import (
	"encoding/json"
	"testing"
)

func TestCorrectToolCallsInSSEData(t *testing.T) {
	corrector := NewCodexToolCorrector()

	tests := []struct {
		name            string
		input           string
		expectCorrected bool
		checkFunc       func(t *testing.T, result string)
	}{
		{
			name:            "empty string",
			input:           "",
			expectCorrected: false,
		},
		{
			name:            "newline only",
			input:           "\n",
			expectCorrected: false,
		},
		{
			name:            "invalid json",
			input:           "not a json",
			expectCorrected: false,
		},
		{
			name:            "correct apply_patch in tool_calls",
			input:           `{"tool_calls":[{"function":{"name":"apply_patch","arguments":"{}"}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				toolCalls, ok := payload["tool_calls"].([]any)
				if !ok || len(toolCalls) == 0 {
					t.Fatal("No tool_calls found in result")
				}
				toolCall, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid tool_call format")
				}
				functionCall, ok := toolCall["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function format")
				}
				if functionCall["name"] != "edit" {
					t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
				}
			},
		},
		{
			name:            "correct update_plan in function_call",
			input:           `{"function_call":{"name":"update_plan","arguments":"{}"}}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				functionCall, ok := payload["function_call"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function_call format")
				}
				if functionCall["name"] != "todowrite" {
					t.Errorf("Expected tool name 'todowrite', got '%v'", functionCall["name"])
				}
			},
		},
		{
			name:            "correct search_files in delta.tool_calls",
			input:           `{"delta":{"tool_calls":[{"function":{"name":"search_files"}}]}}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				delta, ok := payload["delta"].(map[string]any)
				if !ok {
					t.Fatal("Invalid delta format")
				}
				toolCalls, ok := delta["tool_calls"].([]any)
				if !ok || len(toolCalls) == 0 {
					t.Fatal("No tool_calls found in delta")
				}
				toolCall, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid tool_call format")
				}
				functionCall, ok := toolCall["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function format")
				}
				if functionCall["name"] != "grep" {
					t.Errorf("Expected tool name 'grep', got '%v'", functionCall["name"])
				}
			},
		},
		{
			name:            "correct list_files in choices.message.tool_calls",
			input:           `{"choices":[{"message":{"tool_calls":[{"function":{"name":"list_files"}}]}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				choices, ok := payload["choices"].([]any)
				if !ok || len(choices) == 0 {
					t.Fatal("No choices found in result")
				}
				choice, ok := choices[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid choice format")
				}
				message, ok := choice["message"].(map[string]any)
				if !ok {
					t.Fatal("Invalid message format")
				}
				toolCalls, ok := message["tool_calls"].([]any)
				if !ok || len(toolCalls) == 0 {
					t.Fatal("No tool_calls found in message")
				}
				toolCall, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid tool_call format")
				}
				functionCall, ok := toolCall["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function format")
				}
				if functionCall["name"] != "glob" {
					t.Errorf("Expected tool name 'glob', got '%v'", functionCall["name"])
				}
			},
		},
		{
			name:            "no correction needed",
			input:           `{"tool_calls":[{"function":{"name":"read","arguments":"{}"}}]}`,
			expectCorrected: false,
		},
		{
			name:            "correct multiple tool calls",
			input:           `{"tool_calls":[{"function":{"name":"apply_patch"}},{"function":{"name":"read_file"}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				toolCalls, ok := payload["tool_calls"].([]any)
				if !ok || len(toolCalls) < 2 {
					t.Fatal("Expected at least 2 tool_calls")
				}

				toolCall1, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid first tool_call format")
				}
				func1, ok := toolCall1["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid first function format")
				}
				if func1["name"] != "edit" {
					t.Errorf("Expected first tool name 'edit', got '%v'", func1["name"])
				}

				toolCall2, ok := toolCalls[1].(map[string]any)
				if !ok {
					t.Fatal("Invalid second tool_call format")
				}
				func2, ok := toolCall2["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid second function format")
				}
				if func2["name"] != "read" {
					t.Errorf("Expected second tool name 'read', got '%v'", func2["name"])
				}
			},
		},
		{
			name:            "camelCase format - applyPatch",
			input:           `{"tool_calls":[{"function":{"name":"applyPatch"}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				var payload map[string]any
				if err := json.Unmarshal([]byte(result), &payload); err != nil {
					t.Fatalf("Failed to parse result: %v", err)
				}
				toolCalls, ok := payload["tool_calls"].([]any)
				if !ok || len(toolCalls) == 0 {
					t.Fatal("No tool_calls found in result")
				}
				toolCall, ok := toolCalls[0].(map[string]any)
				if !ok {
					t.Fatal("Invalid tool_call format")
				}
				functionCall, ok := toolCall["function"].(map[string]any)
				if !ok {
					t.Fatal("Invalid function format")
				}
				if functionCall["name"] != "edit" {
					t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, corrected := corrector.CorrectToolCallsInSSEData(tt.input)

			if corrected != tt.expectCorrected {
				t.Errorf("Expected corrected=%v, got %v", tt.expectCorrected, corrected)
			}

			if !corrected && result != tt.input {
				t.Errorf("Expected unchanged result when not corrected")
			}

			if tt.checkFunc != nil {
				tt.checkFunc(t, result)
			}
		})
	}
}

func TestCorrectToolName(t *testing.T) {
	tests := []struct {
		input     string
		expected  string
		corrected bool
	}{
		{"apply_patch", "edit", true},
		{"applyPatch", "edit", true},
		{"update_plan", "todowrite", true},
		{"updatePlan", "todowrite", true},
		{"read_plan", "todoread", true},
		{"readPlan", "todoread", true},
		{"search_files", "grep", true},
		{"searchFiles", "grep", true},
		{"list_files", "glob", true},
		{"listFiles", "glob", true},
		{"read_file", "read", true},
		{"readFile", "read", true},
		{"write_file", "write", true},
		{"writeFile", "write", true},
		{"execute_bash", "bash", true},
		{"executeBash", "bash", true},
		{"exec_bash", "bash", true},
		{"execBash", "bash", true},
		{"unknown_tool", "unknown_tool", false},
		{"read", "read", false},
		{"edit", "edit", false},
	}

	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			result, corrected := CorrectToolName(tt.input)

			if corrected != tt.corrected {
				t.Errorf("Expected corrected=%v, got %v", tt.corrected, corrected)
			}

			if result != tt.expected {
				t.Errorf("Expected '%s', got '%s'", tt.expected, result)
			}
		})
	}
}

func TestGetToolNameMapping(t *testing.T) {
	mapping := GetToolNameMapping()

	expectedMappings := map[string]string{
		"apply_patch":  "edit",
		"update_plan":  "todowrite",
		"read_plan":    "todoread",
		"search_files": "grep",
		"list_files":   "glob",
	}

	for from, to := range expectedMappings {
		if mapping[from] != to {
			t.Errorf("Expected mapping[%s] = %s, got %s", from, to, mapping[from])
		}
	}

	mapping["test_tool"] = "test_value"
	newMapping := GetToolNameMapping()
	if _, exists := newMapping["test_tool"]; exists {
		t.Error("Modifications to returned mapping should not affect original")
	}
}

func TestCorrectorStats(t *testing.T) {
	corrector := NewCodexToolCorrector()

	stats := corrector.GetStats()
	if stats.TotalCorrected != 0 {
		t.Errorf("Expected TotalCorrected=0, got %d", stats.TotalCorrected)
	}
	if len(stats.CorrectionsByTool) != 0 {
		t.Errorf("Expected empty CorrectionsByTool, got length %d", len(stats.CorrectionsByTool))
	}

	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"update_plan"}}]}`)

	stats = corrector.GetStats()
	if stats.TotalCorrected != 3 {
		t.Errorf("Expected TotalCorrected=3, got %d", stats.TotalCorrected)
	}

	if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
		t.Errorf("Expected apply_patch->edit count=2, got %d", stats.CorrectionsByTool["apply_patch->edit"])
	}

	if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
		t.Errorf("Expected update_plan->todowrite count=1, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
	}

	corrector.ResetStats()
	stats = corrector.GetStats()
	if stats.TotalCorrected != 0 {
		t.Errorf("Expected TotalCorrected=0 after reset, got %d", stats.TotalCorrected)
	}
	if len(stats.CorrectionsByTool) != 0 {
		t.Errorf("Expected empty CorrectionsByTool after reset, got length %d", len(stats.CorrectionsByTool))
	}
}

func TestComplexSSEData(t *testing.T) {
	corrector := NewCodexToolCorrector()

	input := `{
		"id": "chatcmpl-123",
		"object": "chat.completion.chunk",
		"created": 1234567890,
		"model": "gpt-5.1-codex",
		"choices": [
			{
				"index": 0,
				"delta": {
					"tool_calls": [
						{
							"index": 0,
							"function": {
								"name": "apply_patch",
								"arguments": "{\"file\":\"test.go\"}"
							}
						}
					]
				},
				"finish_reason": null
			}
		]
	}`

	result, corrected := corrector.CorrectToolCallsInSSEData(input)

	if !corrected {
		t.Error("Expected data to be corrected")
	}

	var payload map[string]any
	if err := json.Unmarshal([]byte(result), &payload); err != nil {
		t.Fatalf("Failed to parse result: %v", err)
	}

	choices, ok := payload["choices"].([]any)
	if !ok || len(choices) == 0 {
		t.Fatal("No choices found in result")
	}
	choice, ok := choices[0].(map[string]any)
	if !ok {
		t.Fatal("Invalid choice format")
	}
	delta, ok := choice["delta"].(map[string]any)
	if !ok {
		t.Fatal("Invalid delta format")
	}
	toolCalls, ok := delta["tool_calls"].([]any)
	if !ok || len(toolCalls) == 0 {
		t.Fatal("No tool_calls found in delta")
	}
	toolCall, ok := toolCalls[0].(map[string]any)
	if !ok {
		t.Fatal("Invalid tool_call format")
	}
	function, ok := toolCall["function"].(map[string]any)
	if !ok {
		t.Fatal("Invalid function format")
	}

	if function["name"] != "edit" {
		t.Errorf("Expected tool name 'edit', got '%v'", function["name"])
	}
}

// TestCorrectToolParameters tests tool-parameter correction.
func TestCorrectToolParameters(t *testing.T) {
	corrector := NewCodexToolCorrector()

	tests := []struct {
		name     string
		input    string
		expected map[string]bool // key: parameter name, value: whether it should exist
	}{
		{
			name: "remove workdir from bash tool",
			input: `{
				"tool_calls": [{
					"function": {
						"name": "bash",
						"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
					}
				}]
			}`,
			expected: map[string]bool{
				"command": true,
				"workdir": false,
			},
		},
		{
			name: "rename path to file_path in edit tool",
			input: `{
				"tool_calls": [{
					"function": {
						"name": "apply_patch",
						"arguments": "{\"path\":\"/foo/bar.go\",\"old_string\":\"old\",\"new_string\":\"new\"}"
					}
				}]
			}`,
			expected: map[string]bool{
				"file_path":  true,
				"path":       false,
				"old_string": true,
				"new_string": true,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			corrected, changed := corrector.CorrectToolCallsInSSEData(tt.input)
			if !changed {
				t.Error("expected data to be corrected")
			}

			// Parse the corrected data.
			var result map[string]any
			if err := json.Unmarshal([]byte(corrected), &result); err != nil {
				t.Fatalf("failed to parse corrected data: %v", err)
			}

			// Inspect the tool call.
			toolCalls, ok := result["tool_calls"].([]any)
			if !ok || len(toolCalls) == 0 {
				t.Fatal("no tool_calls found in corrected data")
			}

			toolCall, ok := toolCalls[0].(map[string]any)
			if !ok {
				t.Fatal("invalid tool_call structure")
			}

			function, ok := toolCall["function"].(map[string]any)
			if !ok {
				t.Fatal("no function found in tool_call")
			}

			argumentsStr, ok := function["arguments"].(string)
			if !ok {
				t.Fatal("arguments is not a string")
			}

			var args map[string]any
			if err := json.Unmarshal([]byte(argumentsStr), &args); err != nil {
				t.Fatalf("failed to parse arguments: %v", err)
			}

			// Verify the expected parameters.
			for param, shouldExist := range tt.expected {
				_, exists := args[param]
				if shouldExist && !exists {
					t.Errorf("expected parameter %q to exist, but it doesn't", param)
				}
				if !shouldExist && exists {
					t.Errorf("expected parameter %q to not exist, but it does", param)
				}
			}
		})
	}
}
@@ -4,6 +4,7 @@ import (
 	"context"
 	"database/sql"
 	"errors"
+	"fmt"
 	"log"
 	"strings"
 	"sync"
@@ -235,11 +236,13 @@ func (s *OpsAggregationService) aggregateHourly() {
 	successAt := finishedAt
 	hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
 	defer hbCancel()
+	result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
 	_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
 		JobName:        opsAggHourlyJobName,
 		LastRunAt:      &runAt,
 		LastSuccessAt:  &successAt,
 		LastDurationMs: &dur,
+		LastResult:     &result,
 	})
 }
@@ -331,11 +334,13 @@ func (s *OpsAggregationService) aggregateDaily() {
 	successAt := finishedAt
 	hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
 	defer hbCancel()
+	result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
 	_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
 		JobName:        opsAggDailyJobName,
 		LastRunAt:      &runAt,
 		LastSuccessAt:  &successAt,
 		LastDurationMs: &dur,
+		LastResult:     &result,
 	})
 }
@@ -190,6 +190,13 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
 		return
 	}
+
+	rulesTotal := len(rules)
+	rulesEnabled := 0
+	rulesEvaluated := 0
+	eventsCreated := 0
+	eventsResolved := 0
+	emailsSent := 0
 	now := time.Now().UTC()
 	safeEnd := now.Truncate(time.Minute)
 	if safeEnd.IsZero() {
@@ -205,8 +212,9 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
 		if rule == nil || !rule.Enabled || rule.ID <= 0 {
 			continue
 		}
+		rulesEnabled++
-		scopePlatform, scopeGroupID := parseOpsAlertRuleScope(rule.Filters)
+		scopePlatform, scopeGroupID, scopeRegion := parseOpsAlertRuleScope(rule.Filters)
 		windowMinutes := rule.WindowMinutes
 		if windowMinutes <= 0 {
@@ -220,6 +228,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
 			s.resetRuleState(rule.ID, now)
 			continue
 		}
+		rulesEvaluated++
 		breachedNow := compareMetric(metricValue, rule.Operator, rule.Threshold)
 		required := requiredSustainedBreaches(rule.SustainedMinutes, interval)
@@ -236,6 +245,17 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
 			continue
 		}
+
+		// Scoped silencing: if a matching silence exists, skip creating a firing event.
+		if s.opsService != nil {
+			platform := strings.TrimSpace(scopePlatform)
+			region := scopeRegion
+			if platform != "" {
+				if ok, err := s.opsService.IsAlertSilenced(ctx, rule.ID, platform, scopeGroupID, region, now); err == nil && ok {
+					continue
+				}
+			}
+		}
 		latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
 		if err != nil {
 			log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
@@ -267,8 +287,11 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
 				continue
 			}
+			eventsCreated++
 			if created != nil && created.ID > 0 {
-				s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created)
+				if s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created) {
+					emailsSent++
+				}
 			}
 			continue
 		}
@@ -278,11 +301,14 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
 			resolvedAt := now
 			if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
 				log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err)
+			} else {
+				eventsResolved++
 			}
 		}
 	}
-	s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
+	result := truncateString(fmt.Sprintf("rules=%d enabled=%d evaluated=%d created=%d resolved=%d emails_sent=%d", rulesTotal, rulesEnabled, rulesEvaluated, eventsCreated, eventsResolved, emailsSent), 2048)
+	s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
 }
 
 func (s *OpsAlertEvaluatorService) pruneRuleStates(rules []*OpsAlertRule) {
@@ -359,9 +385,9 @@ func requiredSustainedBreaches(sustainedMinutes int, interval time.Duration) int
 	return required
 }
 
-func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64) {
+func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64, region *string) {
 	if filters == nil {
-		return "", nil
+		return "", nil, nil
 	}
 	if v, ok := filters["platform"]; ok {
 		if s, ok := v.(string); ok {
@@ -392,7 +418,15 @@ func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *i
 		}
 	}
-	return platform, groupID
+	if v, ok := filters["region"]; ok {
+		if s, ok := v.(string); ok {
+			vv := strings.TrimSpace(s)
+			if vv != "" {
+				region = &vv
+			}
+		}
+	}
+	return platform, groupID, region
 }
 
 func (s *OpsAlertEvaluatorService) computeRuleMetric(
@@ -504,16 +538,6 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric(
 		return 0, false
 	}
 	return overview.UpstreamErrorRate * 100, true
|
return overview.UpstreamErrorRate * 100, true
|
||||||
case "p95_latency_ms":
|
|
||||||
if overview.Duration.P95 == nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
return float64(*overview.Duration.P95), true
|
|
||||||
case "p99_latency_ms":
|
|
||||||
if overview.Duration.P99 == nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
return float64(*overview.Duration.P99), true
|
|
||||||
default:
|
default:
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
@@ -576,32 +600,32 @@ func buildOpsAlertDescription(rule *OpsAlertRule, value float64, windowMinutes i
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) {
|
func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) bool {
|
||||||
if s == nil || s.emailService == nil || s.opsService == nil || event == nil || rule == nil {
|
if s == nil || s.emailService == nil || s.opsService == nil || event == nil || rule == nil {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
if event.EmailSent {
|
if event.EmailSent {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
if !rule.NotifyEmail {
|
if !rule.NotifyEmail {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
|
emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
|
||||||
if err != nil || emailCfg == nil || !emailCfg.Alert.Enabled {
|
if err != nil || emailCfg == nil || !emailCfg.Alert.Enabled {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(emailCfg.Alert.Recipients) == 0 {
|
if len(emailCfg.Alert.Recipients) == 0 {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
if !shouldSendOpsAlertEmailByMinSeverity(strings.TrimSpace(emailCfg.Alert.MinSeverity), strings.TrimSpace(rule.Severity)) {
|
if !shouldSendOpsAlertEmailByMinSeverity(strings.TrimSpace(emailCfg.Alert.MinSeverity), strings.TrimSpace(rule.Severity)) {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if runtimeCfg != nil && runtimeCfg.Silencing.Enabled {
|
if runtimeCfg != nil && runtimeCfg.Silencing.Enabled {
|
||||||
if isOpsAlertSilenced(time.Now().UTC(), rule, event, runtimeCfg.Silencing) {
|
if isOpsAlertSilenced(time.Now().UTC(), rule, event, runtimeCfg.Silencing) {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -630,6 +654,7 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt
|
|||||||
if anySent {
|
if anySent {
|
||||||
_ = s.opsRepo.UpdateAlertEventEmailSent(context.Background(), event.ID, true)
|
_ = s.opsRepo.UpdateAlertEventEmailSent(context.Background(), event.ID, true)
|
||||||
}
|
}
|
||||||
|
return anySent
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
|
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
|
||||||
@@ -797,7 +822,7 @@ func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) {
|
|||||||
log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key)
|
log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
|
func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
|
||||||
if s == nil || s.opsRepo == nil {
|
if s == nil || s.opsRepo == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -805,11 +830,17 @@ func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, durat
|
|||||||
durMs := duration.Milliseconds()
|
durMs := duration.Milliseconds()
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
msg := strings.TrimSpace(result)
|
||||||
|
if msg == "" {
|
||||||
|
msg = "ok"
|
||||||
|
}
|
||||||
|
msg = truncateString(msg, 2048)
|
||||||
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||||
JobName: opsAlertEvaluatorJobName,
|
JobName: opsAlertEvaluatorJobName,
|
||||||
LastRunAt: &runAt,
|
LastRunAt: &runAt,
|
||||||
LastSuccessAt: &now,
|
LastSuccessAt: &now,
|
||||||
LastDurationMs: &durMs,
|
LastDurationMs: &durMs,
|
||||||
|
LastResult: &msg,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import "time"
 const (
 	OpsAlertStatusFiring = "firing"
 	OpsAlertStatusResolved = "resolved"
+	OpsAlertStatusManualResolved = "manual_resolved"
 )

 type OpsAlertRule struct {
@@ -58,12 +59,32 @@ type OpsAlertEvent struct {
 	CreatedAt time.Time `json:"created_at"`
 }

+type OpsAlertSilence struct {
+	ID int64 `json:"id"`
+
+	RuleID int64 `json:"rule_id"`
+	Platform string `json:"platform"`
+	GroupID *int64 `json:"group_id,omitempty"`
+	Region *string `json:"region,omitempty"`
+
+	Until time.Time `json:"until"`
+	Reason string `json:"reason"`
+
+	CreatedBy *int64 `json:"created_by,omitempty"`
+	CreatedAt time.Time `json:"created_at"`
+}
+
 type OpsAlertEventFilter struct {
 	Limit int

+	// Cursor pagination (descending by fired_at, then id).
+	BeforeFiredAt *time.Time
+	BeforeID *int64
+
 	// Optional filters.
 	Status string
 	Severity string
+	EmailSent *bool

 	StartTime *time.Time
 	EndTime *time.Time
@@ -88,6 +88,29 @@ func (s *OpsService) ListAlertEvents(ctx context.Context, filter *OpsAlertEventF
 	return s.opsRepo.ListAlertEvents(ctx, filter)
 }

+func (s *OpsService) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) {
+	if err := s.RequireMonitoringEnabled(ctx); err != nil {
+		return nil, err
+	}
+	if s.opsRepo == nil {
+		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
+	}
+	if eventID <= 0 {
+		return nil, infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
+	}
+	ev, err := s.opsRepo.GetAlertEventByID(ctx, eventID)
+	if err != nil {
+		if errors.Is(err, sql.ErrNoRows) {
+			return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
+		}
+		return nil, err
+	}
+	if ev == nil {
+		return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
+	}
+	return ev, nil
+}
+
 func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
 	if err := s.RequireMonitoringEnabled(ctx); err != nil {
 		return nil, err
@@ -101,6 +124,49 @@ func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*Op
 	return s.opsRepo.GetActiveAlertEvent(ctx, ruleID)
 }

+func (s *OpsService) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) {
+	if err := s.RequireMonitoringEnabled(ctx); err != nil {
+		return nil, err
+	}
+	if s.opsRepo == nil {
+		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
+	}
+	if input == nil {
+		return nil, infraerrors.BadRequest("INVALID_SILENCE", "invalid silence")
+	}
+	if input.RuleID <= 0 {
+		return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
+	}
+	if strings.TrimSpace(input.Platform) == "" {
+		return nil, infraerrors.BadRequest("INVALID_PLATFORM", "invalid platform")
+	}
+	if input.Until.IsZero() {
+		return nil, infraerrors.BadRequest("INVALID_UNTIL", "invalid until")
+	}
+
+	created, err := s.opsRepo.CreateAlertSilence(ctx, input)
+	if err != nil {
+		return nil, err
+	}
+	return created, nil
+}
+
+func (s *OpsService) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
+	if err := s.RequireMonitoringEnabled(ctx); err != nil {
+		return false, err
+	}
+	if s.opsRepo == nil {
+		return false, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
+	}
+	if ruleID <= 0 {
+		return false, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
+	}
+	if strings.TrimSpace(platform) == "" {
+		return false, nil
+	}
+	return s.opsRepo.IsAlertSilenced(ctx, ruleID, platform, groupID, region, now)
+}
+
 func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
 	if err := s.RequireMonitoringEnabled(ctx); err != nil {
 		return nil, err
@@ -142,7 +208,11 @@ func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64, 
 	if eventID <= 0 {
 		return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
 	}
-	if strings.TrimSpace(status) == "" {
+	status = strings.TrimSpace(status)
+	if status == "" {
+		return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
+	}
+	if status != OpsAlertStatusResolved && status != OpsAlertStatusManualResolved {
 		return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
 	}
 	return s.opsRepo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
@@ -149,7 +149,7 @@ func (s *OpsCleanupService) runScheduled() {
 		log.Printf("[OpsCleanup] cleanup failed: %v", err)
 		return
 	}
-	s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
+	s.recordHeartbeatSuccess(runAt, time.Since(startedAt), counts)
 	log.Printf("[OpsCleanup] cleanup complete: %s", counts)
 }

@@ -330,12 +330,13 @@ func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), b
 	return release, true
 }

-func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
+func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, counts opsCleanupDeletedCounts) {
 	if s == nil || s.opsRepo == nil {
 		return
 	}
 	now := time.Now().UTC()
 	durMs := duration.Milliseconds()
+	result := truncateString(counts.String(), 2048)
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
 	defer cancel()
 	_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
@@ -343,6 +344,7 @@ func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration tim
 		LastRunAt: &runAt,
 		LastSuccessAt: &now,
 		LastDurationMs: &durMs,
+		LastResult: &result,
 	})
 }

@@ -32,49 +32,38 @@ func computeDashboardHealthScore(now time.Time, overview *OpsDashboardOverview) 
 }

 // computeBusinessHealth calculates business health score (0-100)
-// Components: SLA (50%) + Error Rate (30%) + Latency (20%)
+// Components: Error Rate (50%) + TTFT (50%)
 func computeBusinessHealth(overview *OpsDashboardOverview) float64 {
-	// SLA score: 99.5% → 100, 95% → 0 (linear)
-	slaScore := 100.0
-	slaPct := clampFloat64(overview.SLA*100, 0, 100)
-	if slaPct < 99.5 {
-		if slaPct >= 95 {
-			slaScore = (slaPct - 95) / 4.5 * 100
-		} else {
-			slaScore = 0
-		}
-	}
-
-	// Error rate score: 0.5% → 100, 5% → 0 (linear)
+	// Error rate score: 1% → 100, 10% → 0 (linear)
 	// Combines request errors and upstream errors
 	errorScore := 100.0
 	errorPct := clampFloat64(overview.ErrorRate*100, 0, 100)
 	upstreamPct := clampFloat64(overview.UpstreamErrorRate*100, 0, 100)
 	combinedErrorPct := math.Max(errorPct, upstreamPct) // Use worst case
-	if combinedErrorPct > 0.5 {
-		if combinedErrorPct <= 5 {
-			errorScore = (5 - combinedErrorPct) / 4.5 * 100
+	if combinedErrorPct > 1.0 {
+		if combinedErrorPct <= 10.0 {
+			errorScore = (10.0 - combinedErrorPct) / 9.0 * 100
 		} else {
 			errorScore = 0
 		}
 	}

-	// Latency score: 1s → 100, 10s → 0 (linear)
-	// Uses P99 of duration (TTFT is less critical for overall health)
-	latencyScore := 100.0
-	if overview.Duration.P99 != nil {
-		p99 := float64(*overview.Duration.P99)
+	// TTFT score: 1s → 100, 3s → 0 (linear)
+	// Time to first token is critical for user experience
+	ttftScore := 100.0
+	if overview.TTFT.P99 != nil {
+		p99 := float64(*overview.TTFT.P99)
 		if p99 > 1000 {
-			if p99 <= 10000 {
-				latencyScore = (10000 - p99) / 9000 * 100
+			if p99 <= 3000 {
+				ttftScore = (3000 - p99) / 2000 * 100
 			} else {
-				latencyScore = 0
+				ttftScore = 0
 			}
 		}
 	}

-	// Weighted combination
-	return slaScore*0.5 + errorScore*0.3 + latencyScore*0.2
+	// Weighted combination: 50% error rate + 50% TTFT
+	return errorScore*0.5 + ttftScore*0.5
 }

 // computeInfraHealth calculates infrastructure health score (0-100)
@@ -127,8 +127,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
 					MemoryUsagePercent: float64Ptr(75),
 				},
 			},
-			wantMin: 60,
-			wantMax: 85,
+			wantMin: 96,
+			wantMax: 97,
 		},
 		{
 			name: "DB failure",
@@ -203,8 +203,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
 					MemoryUsagePercent: float64Ptr(30),
 				},
 			},
-			wantMin: 25,
-			wantMax: 50,
+			wantMin: 84,
+			wantMax: 85,
 		},
 		{
 			name: "combined failures - business healthy + infra degraded",
@@ -277,30 +277,41 @@ func TestComputeBusinessHealth(t *testing.T) {
 				UpstreamErrorRate: 0,
 				Duration: OpsPercentiles{P99: intPtr(500)},
 			},
-			wantMin: 50,
-			wantMax: 60,
+			wantMin: 100,
+			wantMax: 100,
 		},
 		{
-			name: "error rate boundary 0.5%",
+			name: "error rate boundary 1%",
 			overview: &OpsDashboardOverview{
-				SLA: 0.995,
-				ErrorRate: 0.005,
+				SLA: 0.99,
+				ErrorRate: 0.01,
 				UpstreamErrorRate: 0,
 				Duration: OpsPercentiles{P99: intPtr(500)},
 			},
-			wantMin: 95,
+			wantMin: 100,
 			wantMax: 100,
 		},
 		{
-			name: "latency boundary 1000ms",
+			name: "error rate 5%",
 			overview: &OpsDashboardOverview{
-				SLA: 0.995,
+				SLA: 0.95,
+				ErrorRate: 0.05,
+				UpstreamErrorRate: 0,
+				Duration: OpsPercentiles{P99: intPtr(500)},
+			},
+			wantMin: 77,
+			wantMax: 78,
+		},
+		{
+			name: "TTFT boundary 2s",
+			overview: &OpsDashboardOverview{
+				SLA: 0.99,
 				ErrorRate: 0,
 				UpstreamErrorRate: 0,
-				Duration: OpsPercentiles{P99: intPtr(1000)},
+				TTFT: OpsPercentiles{P99: intPtr(2000)},
 			},
-			wantMin: 95,
-			wantMax: 100,
+			wantMin: 75,
+			wantMax: 75,
 		},
 		{
 			name: "upstream error dominates",
@@ -310,7 +321,7 @@ func TestComputeBusinessHealth(t *testing.T) {
 				UpstreamErrorRate: 0.03,
 				Duration: OpsPercentiles{P99: intPtr(500)},
 			},
-			wantMin: 75,
+			wantMin: 88,
 			wantMax: 90,
 		},
 	}
@@ -6,24 +6,43 @@ type OpsErrorLog struct {
 	ID int64 `json:"id"`
 	CreatedAt time.Time `json:"created_at"`

+	// Standardized classification
+	// - phase: request|auth|routing|upstream|network|internal
+	// - owner: client|provider|platform
+	// - source: client_request|upstream_http|gateway
 	Phase string `json:"phase"`
 	Type string `json:"type"`

+	Owner string `json:"error_owner"`
+	Source string `json:"error_source"`
+
 	Severity string `json:"severity"`

 	StatusCode int `json:"status_code"`
 	Platform string `json:"platform"`
 	Model string `json:"model"`

-	LatencyMs *int `json:"latency_ms"`
+	IsRetryable bool `json:"is_retryable"`
+	RetryCount int `json:"retry_count"`
+
+	Resolved bool `json:"resolved"`
+	ResolvedAt *time.Time `json:"resolved_at"`
+	ResolvedByUserID *int64 `json:"resolved_by_user_id"`
+	ResolvedByUserName string `json:"resolved_by_user_name"`
+	ResolvedRetryID *int64 `json:"resolved_retry_id"`
+	ResolvedStatusRaw string `json:"-"`
+
 	ClientRequestID string `json:"client_request_id"`
 	RequestID string `json:"request_id"`
 	Message string `json:"message"`

 	UserID *int64 `json:"user_id"`
+	UserEmail string `json:"user_email"`
 	APIKeyID *int64 `json:"api_key_id"`
 	AccountID *int64 `json:"account_id"`
+	AccountName string `json:"account_name"`
 	GroupID *int64 `json:"group_id"`
+	GroupName string `json:"group_name"`

 	ClientIP *string `json:"client_ip"`
 	RequestPath string `json:"request_path"`
@@ -68,8 +87,23 @@ type OpsErrorLogFilter struct {
 	AccountID *int64

 	StatusCodes []int
+	StatusCodesOther bool
 	Phase string
+	Owner string
+	Source string
+	Resolved *bool
 	Query string
+	UserQuery string // Search by user email
+
+	// Optional correlation keys for exact matching.
+	RequestID string
+	ClientRequestID string
+
+	// View controls error categorization for list endpoints.
+	// - errors: show actionable errors (exclude business-limited / 429 / 529)
+	// - excluded: only show excluded errors
+	// - all: show everything
+	View string

 	Page int
 	PageSize int
@@ -90,12 +124,23 @@ type OpsRetryAttempt struct {
 	SourceErrorID int64 `json:"source_error_id"`
 	Mode string `json:"mode"`
 	PinnedAccountID *int64 `json:"pinned_account_id"`
+	PinnedAccountName string `json:"pinned_account_name"`

 	Status string `json:"status"`
 	StartedAt *time.Time `json:"started_at"`
 	FinishedAt *time.Time `json:"finished_at"`
 	DurationMs *int64 `json:"duration_ms"`

+	// Persisted execution results (best-effort)
+	Success *bool `json:"success"`
+	HTTPStatusCode *int `json:"http_status_code"`
+	UpstreamRequestID *string `json:"upstream_request_id"`
+	UsedAccountID *int64 `json:"used_account_id"`
+	UsedAccountName string `json:"used_account_name"`
+	ResponsePreview *string `json:"response_preview"`
+	ResponseTruncated *bool `json:"response_truncated"`
+
+	// Optional correlation
 	ResultRequestID *string `json:"result_request_id"`
 	ResultErrorID *int64 `json:"result_error_id"`
@@ -14,6 +14,8 @@ type OpsRepository interface {
 	InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error)
 	UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error
 	GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error)
+	ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error)
+	UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error

 	// Lightweight window stats (for realtime WS / quick sampling).
 	GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error)
@@ -39,12 +41,17 @@ type OpsRepository interface {
 	DeleteAlertRule(ctx context.Context, id int64) error

 	ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error)
+	GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error)
 	GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
 	GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
 	CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error)
 	UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error
 	UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error
+
+	// Alert silences
+	CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error)
+	IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error)

 	// Pre-aggregation (hourly/daily) used for long-window dashboard performance.
 	UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error
 	UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error
@@ -91,7 +98,6 @@ type OpsInsertErrorLogInput struct {
 	// It is set by OpsService.RecordError before persisting.
 	UpstreamErrorsJSON *string

-	DurationMs *int
 	TimeToFirstTokenMs *int64

 	RequestBodyJSON *string // sanitized json string (not raw bytes)
@@ -124,7 +130,15 @@ type OpsUpdateRetryAttemptInput struct {
 	FinishedAt time.Time
 	DurationMs int64

-	// Optional correlation
+	// Persisted execution results (best-effort)
+	Success *bool
+	HTTPStatusCode *int
+	UpstreamRequestID *string
+	UsedAccountID *int64
+	ResponsePreview *string
+	ResponseTruncated *bool
+
+	// Optional correlation (legacy fields kept)
 	ResultRequestID *string
 	ResultErrorID *int64
@@ -221,6 +235,9 @@ type OpsUpsertJobHeartbeatInput struct {
 	LastErrorAt *time.Time
 	LastError *string
 	LastDurationMs *int64
+
+	// LastResult is an optional human-readable summary of the last successful run.
+	LastResult *string
 }

 type OpsJobHeartbeat struct {
@@ -231,6 +248,7 @@ type OpsJobHeartbeat struct {
 	LastErrorAt *time.Time `json:"last_error_at"`
 	LastError *string `json:"last_error"`
 	LastDurationMs *int64 `json:"last_duration_ms"`
+	LastResult *string `json:"last_result"`

 	UpdatedAt time.Time `json:"updated_at"`
 }
@@ -108,6 +108,10 @@ func (w *limitedResponseWriter) truncated() bool {
 	return w.totalWritten > int64(w.limit)
 }

+const (
+	OpsRetryModeUpstreamEvent = "upstream_event"
+)
+
 func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, errorID int64, mode string, pinnedAccountID *int64) (*OpsRetryResult, error) {
 	if err := s.RequireMonitoringEnabled(ctx); err != nil {
 		return nil, err
@@ -123,6 +127,81 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
 		return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_MODE", "mode must be client or upstream")
 	}

+	errorLog, err := s.GetErrorLogByID(ctx, errorID)
+	if err != nil {
+		return nil, err
+	}
+	if errorLog == nil {
+		return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
+	}
+	if strings.TrimSpace(errorLog.RequestBody) == "" {
+		return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
+	}
+
+	var pinned *int64
+	if mode == OpsRetryModeUpstream {
+		if pinnedAccountID != nil && *pinnedAccountID > 0 {
+			pinned = pinnedAccountID
+		} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
+			pinned = errorLog.AccountID
+		} else {
+			return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
+		}
+	}
+
+	return s.retryWithErrorLog(ctx, requestedByUserID, errorID, mode, mode, pinned, errorLog)
|
}
|
||||||
|
|
||||||
|
// RetryUpstreamEvent retries a specific upstream attempt captured inside ops_error_logs.upstream_errors.
|
||||||
|
// idx is 0-based. It always pins the original event account_id.
|
||||||
|
func (s *OpsService) RetryUpstreamEvent(ctx context.Context, requestedByUserID int64, errorID int64, idx int) (*OpsRetryResult, error) {
|
||||||
|
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.opsRepo == nil {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||||
|
}
|
||||||
|
if idx < 0 {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_UPSTREAM_IDX", "invalid upstream idx")
|
||||||
|
}
|
||||||
|
|
||||||
|
errorLog, err := s.GetErrorLogByID(ctx, errorID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if errorLog == nil {
|
||||||
|
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := ParseOpsUpstreamErrors(errorLog.UpstreamErrors)
|
||||||
|
if err != nil {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENTS_INVALID", "invalid upstream_errors")
|
||||||
|
}
|
||||||
|
if idx >= len(events) {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_IDX_OOB", "upstream idx out of range")
|
||||||
|
}
|
||||||
|
ev := events[idx]
|
||||||
|
if ev == nil {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENT_MISSING", "upstream event missing")
|
||||||
|
}
|
||||||
|
if ev.AccountID <= 0 {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamBody := strings.TrimSpace(ev.UpstreamRequestBody)
|
||||||
|
if upstreamBody == "" {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_NO_REQUEST_BODY", "No upstream request body found to retry")
|
||||||
|
}
|
||||||
|
|
||||||
|
override := *errorLog
|
||||||
|
override.RequestBody = upstreamBody
|
||||||
|
pinned := ev.AccountID
|
||||||
|
|
||||||
|
// Persist as upstream_event, execute as upstream pinned retry.
|
||||||
|
return s.retryWithErrorLog(ctx, requestedByUserID, errorID, OpsRetryModeUpstreamEvent, OpsRetryModeUpstream, &pinned, &override)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsService) retryWithErrorLog(ctx context.Context, requestedByUserID int64, errorID int64, mode string, execMode string, pinnedAccountID *int64, errorLog *OpsErrorLogDetail) (*OpsRetryResult, error) {
|
||||||
latest, err := s.opsRepo.GetLatestRetryAttemptForError(ctx, errorID)
|
latest, err := s.opsRepo.GetLatestRetryAttemptForError(ctx, errorID)
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, infraerrors.InternalServer("OPS_RETRY_LOAD_LATEST_FAILED", "Failed to check retry status").WithCause(err)
|
return nil, infraerrors.InternalServer("OPS_RETRY_LOAD_LATEST_FAILED", "Failed to check retry status").WithCause(err)
|
||||||
@@ -144,22 +223,18 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
errorLog, err := s.GetErrorLogByID(ctx, errorID)
|
if errorLog == nil || strings.TrimSpace(errorLog.RequestBody) == "" {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(errorLog.RequestBody) == "" {
|
|
||||||
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
|
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
|
||||||
}
|
}
|
||||||
|
|
||||||
var pinned *int64
|
var pinned *int64
|
||||||
if mode == OpsRetryModeUpstream {
|
if execMode == OpsRetryModeUpstream {
|
||||||
if pinnedAccountID != nil && *pinnedAccountID > 0 {
|
if pinnedAccountID != nil && *pinnedAccountID > 0 {
|
||||||
pinned = pinnedAccountID
|
pinned = pinnedAccountID
|
||||||
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
|
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
|
||||||
pinned = errorLog.AccountID
|
pinned = errorLog.AccountID
|
||||||
} else {
|
} else {
|
||||||
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
|
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,7 +271,7 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
|
|||||||
execCtx, cancel := context.WithTimeout(ctx, opsRetryTimeout)
|
execCtx, cancel := context.WithTimeout(ctx, opsRetryTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
execRes := s.executeRetry(execCtx, errorLog, mode, pinned)
|
execRes := s.executeRetry(execCtx, errorLog, execMode, pinned)
|
||||||
|
|
||||||
finishedAt := time.Now()
|
finishedAt := time.Now()
|
||||||
result.FinishedAt = finishedAt
|
result.FinishedAt = finishedAt
|
||||||
@@ -220,27 +295,40 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
|
|||||||
msg := result.ErrorMessage
|
msg := result.ErrorMessage
|
||||||
updateErrMsg = &msg
|
updateErrMsg = &msg
|
||||||
}
|
}
|
||||||
|
// Keep legacy result_request_id empty; use upstream_request_id instead.
|
||||||
var resultRequestID *string
|
var resultRequestID *string
|
||||||
if strings.TrimSpace(result.UpstreamRequestID) != "" {
|
|
||||||
v := result.UpstreamRequestID
|
|
||||||
resultRequestID = &v
|
|
||||||
}
|
|
||||||
|
|
||||||
finalStatus := result.Status
|
finalStatus := result.Status
|
||||||
if strings.TrimSpace(finalStatus) == "" {
|
if strings.TrimSpace(finalStatus) == "" {
|
||||||
finalStatus = opsRetryStatusFailed
|
finalStatus = opsRetryStatusFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
success := strings.EqualFold(finalStatus, opsRetryStatusSucceeded)
|
||||||
|
httpStatus := result.HTTPStatusCode
|
||||||
|
upstreamReqID := result.UpstreamRequestID
|
||||||
|
usedAccountID := result.UsedAccountID
|
||||||
|
preview := result.ResponsePreview
|
||||||
|
truncated := result.ResponseTruncated
|
||||||
|
|
||||||
if err := s.opsRepo.UpdateRetryAttempt(updateCtx, &OpsUpdateRetryAttemptInput{
|
if err := s.opsRepo.UpdateRetryAttempt(updateCtx, &OpsUpdateRetryAttemptInput{
|
||||||
ID: attemptID,
|
ID: attemptID,
|
||||||
Status: finalStatus,
|
Status: finalStatus,
|
||||||
FinishedAt: finishedAt,
|
FinishedAt: finishedAt,
|
||||||
DurationMs: result.DurationMs,
|
DurationMs: result.DurationMs,
|
||||||
|
Success: &success,
|
||||||
|
HTTPStatusCode: &httpStatus,
|
||||||
|
UpstreamRequestID: &upstreamReqID,
|
||||||
|
UsedAccountID: usedAccountID,
|
||||||
|
ResponsePreview: &preview,
|
||||||
|
ResponseTruncated: &truncated,
|
||||||
ResultRequestID: resultRequestID,
|
ResultRequestID: resultRequestID,
|
||||||
ErrorMessage: updateErrMsg,
|
ErrorMessage: updateErrMsg,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
// Best-effort: retry itself already executed; do not fail the API response.
|
|
||||||
log.Printf("[Ops] UpdateRetryAttempt failed: %v", err)
|
log.Printf("[Ops] UpdateRetryAttempt failed: %v", err)
|
||||||
|
} else if success {
|
||||||
|
if err := s.opsRepo.UpdateErrorResolution(updateCtx, errorID, true, &requestedByUserID, &attemptID, &finishedAt); err != nil {
|
||||||
|
log.Printf("[Ops] UpdateErrorResolution failed: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
@@ -426,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
|
|||||||
if s.gatewayService == nil {
|
if s.gatewayService == nil {
|
||||||
return nil, fmt.Errorf("gateway service not available")
|
return nil, fmt.Errorf("gateway service not available")
|
||||||
}
|
}
|
||||||
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs)
|
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
|
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
|
||||||
}
|
}
|
||||||
|
|||||||