mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-10 18:14:48 +08:00
Compare commits
143 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f890e85e7 | ||
|
|
10bc7f7042 | ||
|
|
a65fd9dee8 | ||
|
|
1bb4c76deb | ||
|
|
aab44f9fc8 | ||
|
|
0a848e7578 | ||
|
|
90bce60b85 | ||
|
|
c22d51ee41 | ||
|
|
a458e684bc | ||
|
|
87b4662993 | ||
|
|
3a100339b9 | ||
|
|
47eb3c8888 | ||
|
|
4672a6fac3 | ||
|
|
82743704e4 | ||
|
|
cc2d064ab4 | ||
|
|
27214f8657 | ||
|
|
28de614dfb | ||
|
|
850183c269 | ||
|
|
2a5ef6d3f5 | ||
|
|
1d231c6cc3 | ||
|
|
20c71acb3b | ||
|
|
52ad7c6e9c | ||
|
|
5aaaffe4d1 | ||
|
|
5354ba3662 | ||
|
|
2daf13c4c8 | ||
|
|
16a90f3d3a | ||
|
|
8a0ff15242 | ||
|
|
8c993dfd35 | ||
|
|
2a6fb1e456 | ||
|
|
9e6cd36af4 | ||
|
|
f25f992a30 | ||
|
|
841d7ef2f2 | ||
|
|
a7a49be850 | ||
|
|
d5eab7da3b | ||
|
|
9b10241561 | ||
|
|
76448ab555 | ||
|
|
9584af5cb4 | ||
|
|
6fabddcb0b | ||
|
|
5efeabb0c6 | ||
|
|
806f402bba | ||
|
|
5432087d96 | ||
|
|
02cb14c7b8 | ||
|
|
9bdb45be7c | ||
|
|
514c0562e0 | ||
|
|
371275ec34 | ||
|
|
ec24a3c361 | ||
|
|
d89e797bfc | ||
|
|
55e469c7fe | ||
|
|
fb99ceacc7 | ||
|
|
daf10907e4 | ||
|
|
b3b2868f55 | ||
|
|
25b00abca1 | ||
|
|
8d0767352b | ||
|
|
918a253851 | ||
|
|
63711067e6 | ||
|
|
7158b38897 | ||
|
|
7f317b9093 | ||
|
|
7c4309ea24 | ||
|
|
5013290486 | ||
|
|
8cf3e9a620 | ||
|
|
060699c3b8 | ||
|
|
2ca6c631ac | ||
|
|
967e25878f | ||
|
|
182683814b | ||
|
|
99cbfa1567 | ||
|
|
3f8c8d70ad | ||
|
|
9c567fad92 | ||
|
|
33f58d583d | ||
|
|
0abb3a6843 | ||
|
|
3663951d11 | ||
|
|
1e169685f4 | ||
|
|
f38a3e7585 | ||
|
|
b8da5d45ce | ||
|
|
659df6e220 | ||
|
|
d601768016 | ||
|
|
16ddc6a83b | ||
|
|
340dc9cadb | ||
|
|
55fced3942 | ||
|
|
7bbf49fd65 | ||
|
|
eea6c2d02c | ||
|
|
70eaa450db | ||
|
|
55796a118d | ||
|
|
d7fa47d732 | ||
|
|
3d6e01a58f | ||
|
|
f9713e8733 | ||
|
|
0e44829720 | ||
|
|
93db889a10 | ||
|
|
0df7385c4e | ||
|
|
1a3fa6411c | ||
|
|
64614756d1 | ||
|
|
bb1fd54d4d | ||
|
|
d85288a6c0 | ||
|
|
3402acb606 | ||
|
|
7fdc25df3c | ||
|
|
ea699cbdc2 | ||
|
|
fe6a3f4267 | ||
|
|
fe8198c8cd | ||
|
|
675e61385f | ||
|
|
67acac1082 | ||
|
|
d02e1db018 | ||
|
|
0da515071b | ||
|
|
524d80ae1c | ||
|
|
3b71bc3df1 | ||
|
|
22ef9534e0 | ||
|
|
c206d12d5c | ||
|
|
6ad29a470c | ||
|
|
2d45e61a9b | ||
|
|
b98fb013ae | ||
|
|
345a965fa3 | ||
|
|
c02c120579 | ||
|
|
4da681f58a | ||
|
|
68ba866c38 | ||
|
|
9622347faa | ||
|
|
8363663ea8 | ||
|
|
b588ea194c | ||
|
|
465ba76788 | ||
|
|
cf313d5761 | ||
|
|
9618cb5643 | ||
|
|
8c1958c9ad | ||
|
|
2db34139f0 | ||
|
|
9c02ab789d | ||
|
|
e0cccf6ed2 | ||
|
|
89c1a41305 | ||
|
|
202ec21bab | ||
|
|
6dcb27632e | ||
|
|
3141aa5144 | ||
|
|
5443efd7d7 | ||
|
|
62771583e7 | ||
|
|
5526f122b7 | ||
|
|
9c144587fe | ||
|
|
098bf5a1e8 | ||
|
|
4c37ca71ee | ||
|
|
0c52809591 | ||
|
|
53e730f8d5 | ||
|
|
8e248e0853 | ||
|
|
2a0758bdfe | ||
|
|
f55ba3f6c1 | ||
|
|
db51e65b42 | ||
|
|
72a2ed958b | ||
|
|
d0b91a40d4 | ||
|
|
bd74bf7994 | ||
|
|
f28d4b78e7 | ||
|
|
7536dbfee5 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -83,6 +83,8 @@ temp/
|
|||||||
*.log
|
*.log
|
||||||
*.bak
|
*.bak
|
||||||
.cache/
|
.cache/
|
||||||
|
.dev/
|
||||||
|
.serena/
|
||||||
|
|
||||||
# ===================
|
# ===================
|
||||||
# 构建产物
|
# 构建产物
|
||||||
|
|||||||
164
PR_DESCRIPTION.md
Normal file
164
PR_DESCRIPTION.md
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
## 概述
|
||||||
|
|
||||||
|
全面增强运维监控系统(Ops)的错误日志管理和告警静默功能,优化前端 UI 组件代码质量和用户体验。本次更新重构了核心服务层和数据访问层,提升系统可维护性和运维效率。
|
||||||
|
|
||||||
|
## 主要改动
|
||||||
|
|
||||||
|
### 1. 错误日志查询优化
|
||||||
|
|
||||||
|
**功能特性:**
|
||||||
|
- 新增 GetErrorLogByID 接口,支持按 ID 精确查询错误详情
|
||||||
|
- 优化错误日志过滤逻辑,支持多维度筛选(平台、阶段、来源、所有者等)
|
||||||
|
- 改进查询参数处理,简化代码结构
|
||||||
|
- 增强错误分类和标准化处理
|
||||||
|
- 支持错误解决状态追踪(resolved 字段)
|
||||||
|
|
||||||
|
**技术实现:**
|
||||||
|
- `ops_handler.go` - 新增单条错误日志查询接口
|
||||||
|
- `ops_repo.go` - 优化数据查询和过滤条件构建
|
||||||
|
- `ops_models.go` - 扩展错误日志数据模型
|
||||||
|
- 前端 API 接口同步更新
|
||||||
|
|
||||||
|
### 2. 告警静默功能
|
||||||
|
|
||||||
|
**功能特性:**
|
||||||
|
- 支持按规则、平台、分组、区域等维度静默告警
|
||||||
|
- 可设置静默时长和原因说明
|
||||||
|
- 静默记录可追溯,记录创建人和创建时间
|
||||||
|
- 自动过期机制,避免永久静默
|
||||||
|
|
||||||
|
**技术实现:**
|
||||||
|
- `037_ops_alert_silences.sql` - 新增告警静默表
|
||||||
|
- `ops_alerts.go` - 告警静默逻辑实现
|
||||||
|
- `ops_alerts_handler.go` - 告警静默 API 接口
|
||||||
|
- `OpsAlertEventsCard.vue` - 前端告警静默操作界面
|
||||||
|
|
||||||
|
**数据库结构:**
|
||||||
|
|
||||||
|
| 字段 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| rule_id | BIGINT | 告警规则 ID |
|
||||||
|
| platform | VARCHAR(64) | 平台标识 |
|
||||||
|
| group_id | BIGINT | 分组 ID(可选) |
|
||||||
|
| region | VARCHAR(64) | 区域(可选) |
|
||||||
|
| until | TIMESTAMPTZ | 静默截止时间 |
|
||||||
|
| reason | TEXT | 静默原因 |
|
||||||
|
| created_by | BIGINT | 创建人 ID |
|
||||||
|
|
||||||
|
### 3. 错误分类标准化
|
||||||
|
|
||||||
|
**功能特性:**
|
||||||
|
- 统一错误阶段分类(request|auth|routing|upstream|network|internal)
|
||||||
|
- 规范错误归属分类(client|provider|platform)
|
||||||
|
- 标准化错误来源分类(client_request|upstream_http|gateway)
|
||||||
|
- 自动迁移历史数据到新分类体系
|
||||||
|
|
||||||
|
**技术实现:**
|
||||||
|
- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - 分类标准化迁移
|
||||||
|
- 自动映射历史遗留分类到新标准
|
||||||
|
- 自动解决已恢复的上游错误(客户端状态码 < 400)
|
||||||
|
|
||||||
|
### 4. Gateway 服务集成
|
||||||
|
|
||||||
|
**功能特性:**
|
||||||
|
- 完善各 Gateway 服务的 Ops 集成
|
||||||
|
- 统一错误日志记录接口
|
||||||
|
- 增强上游错误追踪能力
|
||||||
|
|
||||||
|
**涉及服务:**
|
||||||
|
- `antigravity_gateway_service.go` - Antigravity 网关集成
|
||||||
|
- `gateway_service.go` - 通用网关集成
|
||||||
|
- `gemini_messages_compat_service.go` - Gemini 兼容层集成
|
||||||
|
- `openai_gateway_service.go` - OpenAI 网关集成
|
||||||
|
|
||||||
|
### 5. 前端 UI 优化
|
||||||
|
|
||||||
|
**代码重构:**
|
||||||
|
- 大幅简化错误详情模态框代码(从 828 行优化到 450 行)
|
||||||
|
- 优化错误日志表格组件,提升可读性
|
||||||
|
- 清理未使用的 i18n 翻译,减少冗余
|
||||||
|
- 统一组件代码风格和格式
|
||||||
|
- 优化骨架屏组件,更好匹配实际看板布局
|
||||||
|
|
||||||
|
**布局改进:**
|
||||||
|
- 修复模态框内容溢出和滚动问题
|
||||||
|
- 优化表格布局,使用 flex 布局确保正确显示
|
||||||
|
- 改进看板头部布局和交互
|
||||||
|
- 提升响应式体验
|
||||||
|
- 骨架屏支持全屏模式适配
|
||||||
|
|
||||||
|
**交互优化:**
|
||||||
|
- 优化告警事件卡片功能和展示
|
||||||
|
- 改进错误详情展示逻辑
|
||||||
|
- 增强请求详情模态框
|
||||||
|
- 完善运行时设置卡片
|
||||||
|
- 改进加载动画效果
|
||||||
|
|
||||||
|
### 6. 国际化完善
|
||||||
|
|
||||||
|
**文案补充:**
|
||||||
|
- 补充错误日志相关的英文翻译
|
||||||
|
- 添加告警静默功能的中英文文案
|
||||||
|
- 完善提示文本和错误信息
|
||||||
|
- 统一术语翻译标准
|
||||||
|
|
||||||
|
## 文件变更
|
||||||
|
|
||||||
|
**后端(26 个文件):**
|
||||||
|
- `backend/internal/handler/admin/ops_alerts_handler.go` - 告警接口增强
|
||||||
|
- `backend/internal/handler/admin/ops_handler.go` - 错误日志接口优化
|
||||||
|
- `backend/internal/handler/ops_error_logger.go` - 错误记录器增强
|
||||||
|
- `backend/internal/repository/ops_repo.go` - 数据访问层重构
|
||||||
|
- `backend/internal/repository/ops_repo_alerts.go` - 告警数据访问增强
|
||||||
|
- `backend/internal/service/ops_*.go` - 核心服务层重构(10 个文件)
|
||||||
|
- `backend/internal/service/*_gateway_service.go` - Gateway 集成(4 个文件)
|
||||||
|
- `backend/internal/server/routes/admin.go` - 路由配置更新
|
||||||
|
- `backend/migrations/*.sql` - 数据库迁移(2 个文件)
|
||||||
|
- 测试文件更新(5 个文件)
|
||||||
|
|
||||||
|
**前端(13 个文件):**
|
||||||
|
- `frontend/src/views/admin/ops/OpsDashboard.vue` - 看板主页优化
|
||||||
|
- `frontend/src/views/admin/ops/components/*.vue` - 组件重构(10 个文件)
|
||||||
|
- `frontend/src/api/admin/ops.ts` - API 接口扩展
|
||||||
|
- `frontend/src/i18n/locales/*.ts` - 国际化文本(2 个文件)
|
||||||
|
|
||||||
|
## 代码统计
|
||||||
|
|
||||||
|
- 44 个文件修改
|
||||||
|
- 3733 行新增
|
||||||
|
- 995 行删除
|
||||||
|
- 净增加 2738 行
|
||||||
|
|
||||||
|
## 核心改进
|
||||||
|
|
||||||
|
**可维护性提升:**
|
||||||
|
- 重构核心服务层,职责更清晰
|
||||||
|
- 简化前端组件代码,降低复杂度
|
||||||
|
- 统一代码风格和命名规范
|
||||||
|
- 清理冗余代码和未使用的翻译
|
||||||
|
- 标准化错误分类体系
|
||||||
|
|
||||||
|
**功能完善:**
|
||||||
|
- 告警静默功能,减少告警噪音
|
||||||
|
- 错误日志查询优化,提升运维效率
|
||||||
|
- Gateway 服务集成完善,统一监控能力
|
||||||
|
- 错误解决状态追踪,便于问题管理
|
||||||
|
|
||||||
|
**用户体验优化:**
|
||||||
|
- 修复多个 UI 布局问题
|
||||||
|
- 优化交互流程
|
||||||
|
- 完善国际化支持
|
||||||
|
- 提升响应式体验
|
||||||
|
- 改进加载状态展示
|
||||||
|
|
||||||
|
## 测试验证
|
||||||
|
|
||||||
|
- ✅ 错误日志查询和过滤功能
|
||||||
|
- ✅ 告警静默创建和自动过期
|
||||||
|
- ✅ 错误分类标准化迁移
|
||||||
|
- ✅ Gateway 服务错误日志记录
|
||||||
|
- ✅ 前端组件布局和交互
|
||||||
|
- ✅ 骨架屏全屏模式适配
|
||||||
|
- ✅ 国际化文本完整性
|
||||||
|
- ✅ API 接口功能正确性
|
||||||
|
- ✅ 数据库迁移执行成功
|
||||||
@@ -57,6 +57,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## OpenAI Responses 兼容注意事项
|
||||||
|
|
||||||
|
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
||||||
|
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 部署方式
|
## 部署方式
|
||||||
|
|
||||||
### 方式一:脚本安装(推荐)
|
### 方式一:脚本安装(推荐)
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ func provideCleanup(
|
|||||||
opsAlertEvaluator *service.OpsAlertEvaluatorService,
|
opsAlertEvaluator *service.OpsAlertEvaluatorService,
|
||||||
opsCleanup *service.OpsCleanupService,
|
opsCleanup *service.OpsCleanupService,
|
||||||
opsScheduledReport *service.OpsScheduledReportService,
|
opsScheduledReport *service.OpsScheduledReportService,
|
||||||
|
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||||
tokenRefresh *service.TokenRefreshService,
|
tokenRefresh *service.TokenRefreshService,
|
||||||
accountExpiry *service.AccountExpiryService,
|
accountExpiry *service.AccountExpiryService,
|
||||||
pricing *service.PricingService,
|
pricing *service.PricingService,
|
||||||
@@ -116,6 +117,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"SchedulerSnapshotService", func() error {
|
||||||
|
if schedulerSnapshot != nil {
|
||||||
|
schedulerSnapshot.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"TokenRefreshService", func() error {
|
{"TokenRefreshService", func() error {
|
||||||
tokenRefresh.Stop()
|
tokenRefresh.Stop()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -67,7 +67,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
userHandler := handler.NewUserHandler(userService)
|
userHandler := handler.NewUserHandler(userService)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||||
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
|
|
||||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
||||||
@@ -76,15 +75,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||||
|
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
|
||||||
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
|
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
|
||||||
|
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
|
||||||
timingWheelService := service.ProvideTimingWheelService()
|
timingWheelService := service.ProvideTimingWheelService()
|
||||||
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
|
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
|
||||||
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
|
|
||||||
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
||||||
accountRepository := repository.NewAccountRepository(client, db)
|
accountRepository := repository.NewAccountRepository(client, db)
|
||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator)
|
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||||
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService)
|
adminUserHandler := admin.NewUserHandler(adminService)
|
||||||
groupHandler := admin.NewGroupHandler(adminService)
|
groupHandler := admin.NewGroupHandler(adminService)
|
||||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
@@ -97,12 +98,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
||||||
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
||||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||||
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache)
|
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||||
|
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||||
|
tokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||||
|
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, tokenCacheInvalidator)
|
||||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||||
usageCache := service.NewUsageCache()
|
usageCache := service.NewUsageCache()
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
|
||||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
|
||||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||||
@@ -121,6 +124,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||||
promoHandler := admin.NewPromoHandler(promoService)
|
promoHandler := admin.NewPromoHandler(promoService)
|
||||||
opsRepository := repository.NewOpsRepository(db)
|
opsRepository := repository.NewOpsRepository(db)
|
||||||
|
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||||
|
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||||
|
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -130,9 +136,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
identityCache := repository.NewIdentityCache(redisClient)
|
identityCache := repository.NewIdentityCache(redisClient)
|
||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
|
||||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
|
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
|
||||||
opsHandler := admin.NewOpsHandler(opsService)
|
opsHandler := admin.NewOpsHandler(opsService)
|
||||||
@@ -162,9 +168,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, tokenCacheInvalidator, configConfig)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
Cleanup: v,
|
Cleanup: v,
|
||||||
@@ -194,6 +200,7 @@ func provideCleanup(
|
|||||||
opsAlertEvaluator *service.OpsAlertEvaluatorService,
|
opsAlertEvaluator *service.OpsAlertEvaluatorService,
|
||||||
opsCleanup *service.OpsCleanupService,
|
opsCleanup *service.OpsCleanupService,
|
||||||
opsScheduledReport *service.OpsScheduledReportService,
|
opsScheduledReport *service.OpsScheduledReportService,
|
||||||
|
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||||
tokenRefresh *service.TokenRefreshService,
|
tokenRefresh *service.TokenRefreshService,
|
||||||
accountExpiry *service.AccountExpiryService,
|
accountExpiry *service.AccountExpiryService,
|
||||||
pricing *service.PricingService,
|
pricing *service.PricingService,
|
||||||
@@ -242,6 +249,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"SchedulerSnapshotService", func() error {
|
||||||
|
if schedulerSnapshot != nil {
|
||||||
|
schedulerSnapshot.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"TokenRefreshService", func() error {
|
{"TokenRefreshService", func() error {
|
||||||
tokenRefresh.Stop()
|
tokenRefresh.Stop()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -43,6 +43,8 @@ type Account struct {
|
|||||||
Concurrency int `json:"concurrency,omitempty"`
|
Concurrency int `json:"concurrency,omitempty"`
|
||||||
// Priority holds the value of the "priority" field.
|
// Priority holds the value of the "priority" field.
|
||||||
Priority int `json:"priority,omitempty"`
|
Priority int `json:"priority,omitempty"`
|
||||||
|
// RateMultiplier holds the value of the "rate_multiplier" field.
|
||||||
|
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
|
||||||
// Status holds the value of the "status" field.
|
// Status holds the value of the "status" field.
|
||||||
Status string `json:"status,omitempty"`
|
Status string `json:"status,omitempty"`
|
||||||
// ErrorMessage holds the value of the "error_message" field.
|
// ErrorMessage holds the value of the "error_message" field.
|
||||||
@@ -135,6 +137,8 @@ func (*Account) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new([]byte)
|
values[i] = new([]byte)
|
||||||
case account.FieldAutoPauseOnExpired, account.FieldSchedulable:
|
case account.FieldAutoPauseOnExpired, account.FieldSchedulable:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
values[i] = new(sql.NullFloat64)
|
||||||
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
|
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
|
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
|
||||||
@@ -241,6 +245,12 @@ func (_m *Account) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.Priority = int(value.Int64)
|
_m.Priority = int(value.Int64)
|
||||||
}
|
}
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field rate_multiplier", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.RateMultiplier = value.Float64
|
||||||
|
}
|
||||||
case account.FieldStatus:
|
case account.FieldStatus:
|
||||||
if value, ok := values[i].(*sql.NullString); !ok {
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field status", values[i])
|
return fmt.Errorf("unexpected type %T for field status", values[i])
|
||||||
@@ -420,6 +430,9 @@ func (_m *Account) String() string {
|
|||||||
builder.WriteString("priority=")
|
builder.WriteString("priority=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
|
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("rate_multiplier=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
|
||||||
|
builder.WriteString(", ")
|
||||||
builder.WriteString("status=")
|
builder.WriteString("status=")
|
||||||
builder.WriteString(_m.Status)
|
builder.WriteString(_m.Status)
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ const (
|
|||||||
FieldConcurrency = "concurrency"
|
FieldConcurrency = "concurrency"
|
||||||
// FieldPriority holds the string denoting the priority field in the database.
|
// FieldPriority holds the string denoting the priority field in the database.
|
||||||
FieldPriority = "priority"
|
FieldPriority = "priority"
|
||||||
|
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
|
||||||
|
FieldRateMultiplier = "rate_multiplier"
|
||||||
// FieldStatus holds the string denoting the status field in the database.
|
// FieldStatus holds the string denoting the status field in the database.
|
||||||
FieldStatus = "status"
|
FieldStatus = "status"
|
||||||
// FieldErrorMessage holds the string denoting the error_message field in the database.
|
// FieldErrorMessage holds the string denoting the error_message field in the database.
|
||||||
@@ -116,6 +118,7 @@ var Columns = []string{
|
|||||||
FieldProxyID,
|
FieldProxyID,
|
||||||
FieldConcurrency,
|
FieldConcurrency,
|
||||||
FieldPriority,
|
FieldPriority,
|
||||||
|
FieldRateMultiplier,
|
||||||
FieldStatus,
|
FieldStatus,
|
||||||
FieldErrorMessage,
|
FieldErrorMessage,
|
||||||
FieldLastUsedAt,
|
FieldLastUsedAt,
|
||||||
@@ -174,6 +177,8 @@ var (
|
|||||||
DefaultConcurrency int
|
DefaultConcurrency int
|
||||||
// DefaultPriority holds the default value on creation for the "priority" field.
|
// DefaultPriority holds the default value on creation for the "priority" field.
|
||||||
DefaultPriority int
|
DefaultPriority int
|
||||||
|
// DefaultRateMultiplier holds the default value on creation for the "rate_multiplier" field.
|
||||||
|
DefaultRateMultiplier float64
|
||||||
// DefaultStatus holds the default value on creation for the "status" field.
|
// DefaultStatus holds the default value on creation for the "status" field.
|
||||||
DefaultStatus string
|
DefaultStatus string
|
||||||
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||||
@@ -244,6 +249,11 @@ func ByPriority(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldPriority, opts...).ToFunc()
|
return sql.OrderByField(FieldPriority, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByRateMultiplier orders the results by the rate_multiplier field.
|
||||||
|
func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByStatus orders the results by the status field.
|
// ByStatus orders the results by the status field.
|
||||||
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
|
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldStatus, opts...).ToFunc()
|
return sql.OrderByField(FieldStatus, opts...).ToFunc()
|
||||||
|
|||||||
@@ -105,6 +105,11 @@ func Priority(v int) predicate.Account {
|
|||||||
return predicate.Account(sql.FieldEQ(FieldPriority, v))
|
return predicate.Account(sql.FieldEQ(FieldPriority, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RateMultiplier applies equality check predicate on the "rate_multiplier" field. It's identical to RateMultiplierEQ.
|
||||||
|
func RateMultiplier(v float64) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
|
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
|
||||||
func Status(v string) predicate.Account {
|
func Status(v string) predicate.Account {
|
||||||
return predicate.Account(sql.FieldEQ(FieldStatus, v))
|
return predicate.Account(sql.FieldEQ(FieldStatus, v))
|
||||||
@@ -675,6 +680,46 @@ func PriorityLTE(v int) predicate.Account {
|
|||||||
return predicate.Account(sql.FieldLTE(FieldPriority, v))
|
return predicate.Account(sql.FieldLTE(FieldPriority, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RateMultiplierEQ applies the EQ predicate on the "rate_multiplier" field.
|
||||||
|
func RateMultiplierEQ(v float64) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateMultiplierNEQ applies the NEQ predicate on the "rate_multiplier" field.
|
||||||
|
func RateMultiplierNEQ(v float64) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldNEQ(FieldRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateMultiplierIn applies the In predicate on the "rate_multiplier" field.
|
||||||
|
func RateMultiplierIn(vs ...float64) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldIn(FieldRateMultiplier, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateMultiplierNotIn applies the NotIn predicate on the "rate_multiplier" field.
|
||||||
|
func RateMultiplierNotIn(vs ...float64) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldNotIn(FieldRateMultiplier, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateMultiplierGT applies the GT predicate on the "rate_multiplier" field.
|
||||||
|
func RateMultiplierGT(v float64) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldGT(FieldRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateMultiplierGTE applies the GTE predicate on the "rate_multiplier" field.
|
||||||
|
func RateMultiplierGTE(v float64) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldGTE(FieldRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateMultiplierLT applies the LT predicate on the "rate_multiplier" field.
|
||||||
|
func RateMultiplierLT(v float64) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldLT(FieldRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateMultiplierLTE applies the LTE predicate on the "rate_multiplier" field.
|
||||||
|
func RateMultiplierLTE(v float64) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldLTE(FieldRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
// StatusEQ applies the EQ predicate on the "status" field.
|
// StatusEQ applies the EQ predicate on the "status" field.
|
||||||
func StatusEQ(v string) predicate.Account {
|
func StatusEQ(v string) predicate.Account {
|
||||||
return predicate.Account(sql.FieldEQ(FieldStatus, v))
|
return predicate.Account(sql.FieldEQ(FieldStatus, v))
|
||||||
|
|||||||
@@ -153,6 +153,20 @@ func (_c *AccountCreate) SetNillablePriority(v *int) *AccountCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (_c *AccountCreate) SetRateMultiplier(v float64) *AccountCreate {
|
||||||
|
_c.mutation.SetRateMultiplier(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
|
||||||
|
func (_c *AccountCreate) SetNillableRateMultiplier(v *float64) *AccountCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetRateMultiplier(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (_c *AccountCreate) SetStatus(v string) *AccountCreate {
|
func (_c *AccountCreate) SetStatus(v string) *AccountCreate {
|
||||||
_c.mutation.SetStatus(v)
|
_c.mutation.SetStatus(v)
|
||||||
@@ -429,6 +443,10 @@ func (_c *AccountCreate) defaults() error {
|
|||||||
v := account.DefaultPriority
|
v := account.DefaultPriority
|
||||||
_c.mutation.SetPriority(v)
|
_c.mutation.SetPriority(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.RateMultiplier(); !ok {
|
||||||
|
v := account.DefaultRateMultiplier
|
||||||
|
_c.mutation.SetRateMultiplier(v)
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.Status(); !ok {
|
if _, ok := _c.mutation.Status(); !ok {
|
||||||
v := account.DefaultStatus
|
v := account.DefaultStatus
|
||||||
_c.mutation.SetStatus(v)
|
_c.mutation.SetStatus(v)
|
||||||
@@ -488,6 +506,9 @@ func (_c *AccountCreate) check() error {
|
|||||||
if _, ok := _c.mutation.Priority(); !ok {
|
if _, ok := _c.mutation.Priority(); !ok {
|
||||||
return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "Account.priority"`)}
|
return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "Account.priority"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.RateMultiplier(); !ok {
|
||||||
|
return &ValidationError{Name: "rate_multiplier", err: errors.New(`ent: missing required field "Account.rate_multiplier"`)}
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.Status(); !ok {
|
if _, ok := _c.mutation.Status(); !ok {
|
||||||
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Account.status"`)}
|
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Account.status"`)}
|
||||||
}
|
}
|
||||||
@@ -578,6 +599,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||||
_node.Priority = value
|
_node.Priority = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.RateMultiplier(); ok {
|
||||||
|
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
|
_node.RateMultiplier = value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.Status(); ok {
|
if value, ok := _c.mutation.Status(); ok {
|
||||||
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
||||||
_node.Status = value
|
_node.Status = value
|
||||||
@@ -893,6 +918,24 @@ func (u *AccountUpsert) AddPriority(v int) *AccountUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (u *AccountUpsert) SetRateMultiplier(v float64) *AccountUpsert {
|
||||||
|
u.Set(account.FieldRateMultiplier, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
|
||||||
|
func (u *AccountUpsert) UpdateRateMultiplier() *AccountUpsert {
|
||||||
|
u.SetExcluded(account.FieldRateMultiplier)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRateMultiplier adds v to the "rate_multiplier" field.
|
||||||
|
func (u *AccountUpsert) AddRateMultiplier(v float64) *AccountUpsert {
|
||||||
|
u.Add(account.FieldRateMultiplier, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (u *AccountUpsert) SetStatus(v string) *AccountUpsert {
|
func (u *AccountUpsert) SetStatus(v string) *AccountUpsert {
|
||||||
u.Set(account.FieldStatus, v)
|
u.Set(account.FieldStatus, v)
|
||||||
@@ -1325,6 +1368,27 @@ func (u *AccountUpsertOne) UpdatePriority() *AccountUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (u *AccountUpsertOne) SetRateMultiplier(v float64) *AccountUpsertOne {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.SetRateMultiplier(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRateMultiplier adds v to the "rate_multiplier" field.
|
||||||
|
func (u *AccountUpsertOne) AddRateMultiplier(v float64) *AccountUpsertOne {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.AddRateMultiplier(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
|
||||||
|
func (u *AccountUpsertOne) UpdateRateMultiplier() *AccountUpsertOne {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.UpdateRateMultiplier()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (u *AccountUpsertOne) SetStatus(v string) *AccountUpsertOne {
|
func (u *AccountUpsertOne) SetStatus(v string) *AccountUpsertOne {
|
||||||
return u.Update(func(s *AccountUpsert) {
|
return u.Update(func(s *AccountUpsert) {
|
||||||
@@ -1956,6 +2020,27 @@ func (u *AccountUpsertBulk) UpdatePriority() *AccountUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (u *AccountUpsertBulk) SetRateMultiplier(v float64) *AccountUpsertBulk {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.SetRateMultiplier(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRateMultiplier adds v to the "rate_multiplier" field.
|
||||||
|
func (u *AccountUpsertBulk) AddRateMultiplier(v float64) *AccountUpsertBulk {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.AddRateMultiplier(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
|
||||||
|
func (u *AccountUpsertBulk) UpdateRateMultiplier() *AccountUpsertBulk {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.UpdateRateMultiplier()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (u *AccountUpsertBulk) SetStatus(v string) *AccountUpsertBulk {
|
func (u *AccountUpsertBulk) SetStatus(v string) *AccountUpsertBulk {
|
||||||
return u.Update(func(s *AccountUpsert) {
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
|||||||
@@ -193,6 +193,27 @@ func (_u *AccountUpdate) AddPriority(v int) *AccountUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (_u *AccountUpdate) SetRateMultiplier(v float64) *AccountUpdate {
|
||||||
|
_u.mutation.ResetRateMultiplier()
|
||||||
|
_u.mutation.SetRateMultiplier(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
|
||||||
|
func (_u *AccountUpdate) SetNillableRateMultiplier(v *float64) *AccountUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetRateMultiplier(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRateMultiplier adds value to the "rate_multiplier" field.
|
||||||
|
func (_u *AccountUpdate) AddRateMultiplier(v float64) *AccountUpdate {
|
||||||
|
_u.mutation.AddRateMultiplier(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (_u *AccountUpdate) SetStatus(v string) *AccountUpdate {
|
func (_u *AccountUpdate) SetStatus(v string) *AccountUpdate {
|
||||||
_u.mutation.SetStatus(v)
|
_u.mutation.SetStatus(v)
|
||||||
@@ -629,6 +650,12 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.AddedPriority(); ok {
|
if value, ok := _u.mutation.AddedPriority(); ok {
|
||||||
_spec.AddField(account.FieldPriority, field.TypeInt, value)
|
_spec.AddField(account.FieldPriority, field.TypeInt, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.RateMultiplier(); ok {
|
||||||
|
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
|
||||||
|
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.Status(); ok {
|
if value, ok := _u.mutation.Status(); ok {
|
||||||
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
||||||
}
|
}
|
||||||
@@ -1005,6 +1032,27 @@ func (_u *AccountUpdateOne) AddPriority(v int) *AccountUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (_u *AccountUpdateOne) SetRateMultiplier(v float64) *AccountUpdateOne {
|
||||||
|
_u.mutation.ResetRateMultiplier()
|
||||||
|
_u.mutation.SetRateMultiplier(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
|
||||||
|
func (_u *AccountUpdateOne) SetNillableRateMultiplier(v *float64) *AccountUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetRateMultiplier(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRateMultiplier adds value to the "rate_multiplier" field.
|
||||||
|
func (_u *AccountUpdateOne) AddRateMultiplier(v float64) *AccountUpdateOne {
|
||||||
|
_u.mutation.AddRateMultiplier(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (_u *AccountUpdateOne) SetStatus(v string) *AccountUpdateOne {
|
func (_u *AccountUpdateOne) SetStatus(v string) *AccountUpdateOne {
|
||||||
_u.mutation.SetStatus(v)
|
_u.mutation.SetStatus(v)
|
||||||
@@ -1471,6 +1519,12 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
|
|||||||
if value, ok := _u.mutation.AddedPriority(); ok {
|
if value, ok := _u.mutation.AddedPriority(); ok {
|
||||||
_spec.AddField(account.FieldPriority, field.TypeInt, value)
|
_spec.AddField(account.FieldPriority, field.TypeInt, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.RateMultiplier(); ok {
|
||||||
|
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
|
||||||
|
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.Status(); ok {
|
if value, ok := _u.mutation.Status(); ok {
|
||||||
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ var (
|
|||||||
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||||
{Name: "concurrency", Type: field.TypeInt, Default: 3},
|
{Name: "concurrency", Type: field.TypeInt, Default: 3},
|
||||||
{Name: "priority", Type: field.TypeInt, Default: 50},
|
{Name: "priority", Type: field.TypeInt, Default: 50},
|
||||||
|
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||||
{Name: "error_message", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
{Name: "error_message", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||||
{Name: "last_used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
{Name: "last_used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
@@ -101,7 +102,7 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "accounts_proxies_proxy",
|
Symbol: "accounts_proxies_proxy",
|
||||||
Columns: []*schema.Column{AccountsColumns[24]},
|
Columns: []*schema.Column{AccountsColumns[25]},
|
||||||
RefColumns: []*schema.Column{ProxiesColumns[0]},
|
RefColumns: []*schema.Column{ProxiesColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
@@ -120,12 +121,12 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "account_status",
|
Name: "account_status",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[12]},
|
Columns: []*schema.Column{AccountsColumns[13]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_proxy_id",
|
Name: "account_proxy_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[24]},
|
Columns: []*schema.Column{AccountsColumns[25]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_priority",
|
Name: "account_priority",
|
||||||
@@ -135,27 +136,27 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "account_last_used_at",
|
Name: "account_last_used_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[14]},
|
Columns: []*schema.Column{AccountsColumns[15]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_schedulable",
|
Name: "account_schedulable",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[17]},
|
Columns: []*schema.Column{AccountsColumns[18]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_rate_limited_at",
|
Name: "account_rate_limited_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[18]},
|
Columns: []*schema.Column{AccountsColumns[19]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_rate_limit_reset_at",
|
Name: "account_rate_limit_reset_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[19]},
|
Columns: []*schema.Column{AccountsColumns[20]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_overload_until",
|
Name: "account_overload_until",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[20]},
|
Columns: []*schema.Column{AccountsColumns[21]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_deleted_at",
|
Name: "account_deleted_at",
|
||||||
@@ -449,6 +450,7 @@ var (
|
|||||||
{Name: "total_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
{Name: "total_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||||
{Name: "actual_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
{Name: "actual_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||||
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||||
|
{Name: "account_rate_multiplier", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||||
{Name: "billing_type", Type: field.TypeInt8, Default: 0},
|
{Name: "billing_type", Type: field.TypeInt8, Default: 0},
|
||||||
{Name: "stream", Type: field.TypeBool, Default: false},
|
{Name: "stream", Type: field.TypeBool, Default: false},
|
||||||
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
|
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
|
||||||
@@ -472,31 +474,31 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_api_keys_usage_logs",
|
Symbol: "usage_logs_api_keys_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_accounts_usage_logs",
|
Symbol: "usage_logs_accounts_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_groups_usage_logs",
|
Symbol: "usage_logs_groups_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_users_usage_logs",
|
Symbol: "usage_logs_users_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||||
OnDelete: schema.NoAction,
|
OnDelete: schema.NoAction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
@@ -505,32 +507,32 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id",
|
Name: "usagelog_user_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id",
|
Name: "usagelog_api_key_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_account_id",
|
Name: "usagelog_account_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_group_id",
|
Name: "usagelog_group_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_subscription_id",
|
Name: "usagelog_subscription_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_created_at",
|
Name: "usagelog_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[24]},
|
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_model",
|
Name: "usagelog_model",
|
||||||
@@ -545,12 +547,12 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "usagelog_user_id_created_at",
|
Name: "usagelog_user_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]},
|
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "usagelog_api_key_id_created_at",
|
Name: "usagelog_api_key_id_created_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]},
|
Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1187,6 +1187,8 @@ type AccountMutation struct {
|
|||||||
addconcurrency *int
|
addconcurrency *int
|
||||||
priority *int
|
priority *int
|
||||||
addpriority *int
|
addpriority *int
|
||||||
|
rate_multiplier *float64
|
||||||
|
addrate_multiplier *float64
|
||||||
status *string
|
status *string
|
||||||
error_message *string
|
error_message *string
|
||||||
last_used_at *time.Time
|
last_used_at *time.Time
|
||||||
@@ -1822,6 +1824,62 @@ func (m *AccountMutation) ResetPriority() {
|
|||||||
m.addpriority = nil
|
m.addpriority = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||||
|
func (m *AccountMutation) SetRateMultiplier(f float64) {
|
||||||
|
m.rate_multiplier = &f
|
||||||
|
m.addrate_multiplier = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateMultiplier returns the value of the "rate_multiplier" field in the mutation.
|
||||||
|
func (m *AccountMutation) RateMultiplier() (r float64, exists bool) {
|
||||||
|
v := m.rate_multiplier
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldRateMultiplier returns the old "rate_multiplier" field's value of the Account entity.
|
||||||
|
// If the Account object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *AccountMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldRateMultiplier requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.RateMultiplier, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRateMultiplier adds f to the "rate_multiplier" field.
|
||||||
|
func (m *AccountMutation) AddRateMultiplier(f float64) {
|
||||||
|
if m.addrate_multiplier != nil {
|
||||||
|
*m.addrate_multiplier += f
|
||||||
|
} else {
|
||||||
|
m.addrate_multiplier = &f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation.
|
||||||
|
func (m *AccountMutation) AddedRateMultiplier() (r float64, exists bool) {
|
||||||
|
v := m.addrate_multiplier
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetRateMultiplier resets all changes to the "rate_multiplier" field.
|
||||||
|
func (m *AccountMutation) ResetRateMultiplier() {
|
||||||
|
m.rate_multiplier = nil
|
||||||
|
m.addrate_multiplier = nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatus sets the "status" field.
|
// SetStatus sets the "status" field.
|
||||||
func (m *AccountMutation) SetStatus(s string) {
|
func (m *AccountMutation) SetStatus(s string) {
|
||||||
m.status = &s
|
m.status = &s
|
||||||
@@ -2540,7 +2598,7 @@ func (m *AccountMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *AccountMutation) Fields() []string {
|
func (m *AccountMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 24)
|
fields := make([]string, 0, 25)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, account.FieldCreatedAt)
|
fields = append(fields, account.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -2577,6 +2635,9 @@ func (m *AccountMutation) Fields() []string {
|
|||||||
if m.priority != nil {
|
if m.priority != nil {
|
||||||
fields = append(fields, account.FieldPriority)
|
fields = append(fields, account.FieldPriority)
|
||||||
}
|
}
|
||||||
|
if m.rate_multiplier != nil {
|
||||||
|
fields = append(fields, account.FieldRateMultiplier)
|
||||||
|
}
|
||||||
if m.status != nil {
|
if m.status != nil {
|
||||||
fields = append(fields, account.FieldStatus)
|
fields = append(fields, account.FieldStatus)
|
||||||
}
|
}
|
||||||
@@ -2645,6 +2706,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.Concurrency()
|
return m.Concurrency()
|
||||||
case account.FieldPriority:
|
case account.FieldPriority:
|
||||||
return m.Priority()
|
return m.Priority()
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
return m.RateMultiplier()
|
||||||
case account.FieldStatus:
|
case account.FieldStatus:
|
||||||
return m.Status()
|
return m.Status()
|
||||||
case account.FieldErrorMessage:
|
case account.FieldErrorMessage:
|
||||||
@@ -2702,6 +2765,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value,
|
|||||||
return m.OldConcurrency(ctx)
|
return m.OldConcurrency(ctx)
|
||||||
case account.FieldPriority:
|
case account.FieldPriority:
|
||||||
return m.OldPriority(ctx)
|
return m.OldPriority(ctx)
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
return m.OldRateMultiplier(ctx)
|
||||||
case account.FieldStatus:
|
case account.FieldStatus:
|
||||||
return m.OldStatus(ctx)
|
return m.OldStatus(ctx)
|
||||||
case account.FieldErrorMessage:
|
case account.FieldErrorMessage:
|
||||||
@@ -2819,6 +2884,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetPriority(v)
|
m.SetPriority(v)
|
||||||
return nil
|
return nil
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetRateMultiplier(v)
|
||||||
|
return nil
|
||||||
case account.FieldStatus:
|
case account.FieldStatus:
|
||||||
v, ok := value.(string)
|
v, ok := value.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -2917,6 +2989,9 @@ func (m *AccountMutation) AddedFields() []string {
|
|||||||
if m.addpriority != nil {
|
if m.addpriority != nil {
|
||||||
fields = append(fields, account.FieldPriority)
|
fields = append(fields, account.FieldPriority)
|
||||||
}
|
}
|
||||||
|
if m.addrate_multiplier != nil {
|
||||||
|
fields = append(fields, account.FieldRateMultiplier)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2929,6 +3004,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) {
|
|||||||
return m.AddedConcurrency()
|
return m.AddedConcurrency()
|
||||||
case account.FieldPriority:
|
case account.FieldPriority:
|
||||||
return m.AddedPriority()
|
return m.AddedPriority()
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
return m.AddedRateMultiplier()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -2952,6 +3029,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.AddPriority(v)
|
m.AddPriority(v)
|
||||||
return nil
|
return nil
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.AddRateMultiplier(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Account numeric field %s", name)
|
return fmt.Errorf("unknown Account numeric field %s", name)
|
||||||
}
|
}
|
||||||
@@ -3090,6 +3174,9 @@ func (m *AccountMutation) ResetField(name string) error {
|
|||||||
case account.FieldPriority:
|
case account.FieldPriority:
|
||||||
m.ResetPriority()
|
m.ResetPriority()
|
||||||
return nil
|
return nil
|
||||||
|
case account.FieldRateMultiplier:
|
||||||
|
m.ResetRateMultiplier()
|
||||||
|
return nil
|
||||||
case account.FieldStatus:
|
case account.FieldStatus:
|
||||||
m.ResetStatus()
|
m.ResetStatus()
|
||||||
return nil
|
return nil
|
||||||
@@ -10190,6 +10277,8 @@ type UsageLogMutation struct {
|
|||||||
addactual_cost *float64
|
addactual_cost *float64
|
||||||
rate_multiplier *float64
|
rate_multiplier *float64
|
||||||
addrate_multiplier *float64
|
addrate_multiplier *float64
|
||||||
|
account_rate_multiplier *float64
|
||||||
|
addaccount_rate_multiplier *float64
|
||||||
billing_type *int8
|
billing_type *int8
|
||||||
addbilling_type *int8
|
addbilling_type *int8
|
||||||
stream *bool
|
stream *bool
|
||||||
@@ -11323,6 +11412,76 @@ func (m *UsageLogMutation) ResetRateMultiplier() {
|
|||||||
m.addrate_multiplier = nil
|
m.addrate_multiplier = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||||
|
func (m *UsageLogMutation) SetAccountRateMultiplier(f float64) {
|
||||||
|
m.account_rate_multiplier = &f
|
||||||
|
m.addaccount_rate_multiplier = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplier returns the value of the "account_rate_multiplier" field in the mutation.
|
||||||
|
func (m *UsageLogMutation) AccountRateMultiplier() (r float64, exists bool) {
|
||||||
|
v := m.account_rate_multiplier
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldAccountRateMultiplier returns the old "account_rate_multiplier" field's value of the UsageLog entity.
|
||||||
|
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UsageLogMutation) OldAccountRateMultiplier(ctx context.Context) (v *float64, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldAccountRateMultiplier is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldAccountRateMultiplier requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldAccountRateMultiplier: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.AccountRateMultiplier, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAccountRateMultiplier adds f to the "account_rate_multiplier" field.
|
||||||
|
func (m *UsageLogMutation) AddAccountRateMultiplier(f float64) {
|
||||||
|
if m.addaccount_rate_multiplier != nil {
|
||||||
|
*m.addaccount_rate_multiplier += f
|
||||||
|
} else {
|
||||||
|
m.addaccount_rate_multiplier = &f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedAccountRateMultiplier returns the value that was added to the "account_rate_multiplier" field in this mutation.
|
||||||
|
func (m *UsageLogMutation) AddedAccountRateMultiplier() (r float64, exists bool) {
|
||||||
|
v := m.addaccount_rate_multiplier
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
|
||||||
|
func (m *UsageLogMutation) ClearAccountRateMultiplier() {
|
||||||
|
m.account_rate_multiplier = nil
|
||||||
|
m.addaccount_rate_multiplier = nil
|
||||||
|
m.clearedFields[usagelog.FieldAccountRateMultiplier] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierCleared returns if the "account_rate_multiplier" field was cleared in this mutation.
|
||||||
|
func (m *UsageLogMutation) AccountRateMultiplierCleared() bool {
|
||||||
|
_, ok := m.clearedFields[usagelog.FieldAccountRateMultiplier]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetAccountRateMultiplier resets all changes to the "account_rate_multiplier" field.
|
||||||
|
func (m *UsageLogMutation) ResetAccountRateMultiplier() {
|
||||||
|
m.account_rate_multiplier = nil
|
||||||
|
m.addaccount_rate_multiplier = nil
|
||||||
|
delete(m.clearedFields, usagelog.FieldAccountRateMultiplier)
|
||||||
|
}
|
||||||
|
|
||||||
// SetBillingType sets the "billing_type" field.
|
// SetBillingType sets the "billing_type" field.
|
||||||
func (m *UsageLogMutation) SetBillingType(i int8) {
|
func (m *UsageLogMutation) SetBillingType(i int8) {
|
||||||
m.billing_type = &i
|
m.billing_type = &i
|
||||||
@@ -11963,7 +12122,7 @@ func (m *UsageLogMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *UsageLogMutation) Fields() []string {
|
func (m *UsageLogMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 29)
|
fields := make([]string, 0, 30)
|
||||||
if m.user != nil {
|
if m.user != nil {
|
||||||
fields = append(fields, usagelog.FieldUserID)
|
fields = append(fields, usagelog.FieldUserID)
|
||||||
}
|
}
|
||||||
@@ -12024,6 +12183,9 @@ func (m *UsageLogMutation) Fields() []string {
|
|||||||
if m.rate_multiplier != nil {
|
if m.rate_multiplier != nil {
|
||||||
fields = append(fields, usagelog.FieldRateMultiplier)
|
fields = append(fields, usagelog.FieldRateMultiplier)
|
||||||
}
|
}
|
||||||
|
if m.account_rate_multiplier != nil {
|
||||||
|
fields = append(fields, usagelog.FieldAccountRateMultiplier)
|
||||||
|
}
|
||||||
if m.billing_type != nil {
|
if m.billing_type != nil {
|
||||||
fields = append(fields, usagelog.FieldBillingType)
|
fields = append(fields, usagelog.FieldBillingType)
|
||||||
}
|
}
|
||||||
@@ -12099,6 +12261,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.ActualCost()
|
return m.ActualCost()
|
||||||
case usagelog.FieldRateMultiplier:
|
case usagelog.FieldRateMultiplier:
|
||||||
return m.RateMultiplier()
|
return m.RateMultiplier()
|
||||||
|
case usagelog.FieldAccountRateMultiplier:
|
||||||
|
return m.AccountRateMultiplier()
|
||||||
case usagelog.FieldBillingType:
|
case usagelog.FieldBillingType:
|
||||||
return m.BillingType()
|
return m.BillingType()
|
||||||
case usagelog.FieldStream:
|
case usagelog.FieldStream:
|
||||||
@@ -12166,6 +12330,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
|||||||
return m.OldActualCost(ctx)
|
return m.OldActualCost(ctx)
|
||||||
case usagelog.FieldRateMultiplier:
|
case usagelog.FieldRateMultiplier:
|
||||||
return m.OldRateMultiplier(ctx)
|
return m.OldRateMultiplier(ctx)
|
||||||
|
case usagelog.FieldAccountRateMultiplier:
|
||||||
|
return m.OldAccountRateMultiplier(ctx)
|
||||||
case usagelog.FieldBillingType:
|
case usagelog.FieldBillingType:
|
||||||
return m.OldBillingType(ctx)
|
return m.OldBillingType(ctx)
|
||||||
case usagelog.FieldStream:
|
case usagelog.FieldStream:
|
||||||
@@ -12333,6 +12499,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetRateMultiplier(v)
|
m.SetRateMultiplier(v)
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldAccountRateMultiplier:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetAccountRateMultiplier(v)
|
||||||
|
return nil
|
||||||
case usagelog.FieldBillingType:
|
case usagelog.FieldBillingType:
|
||||||
v, ok := value.(int8)
|
v, ok := value.(int8)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -12443,6 +12616,9 @@ func (m *UsageLogMutation) AddedFields() []string {
|
|||||||
if m.addrate_multiplier != nil {
|
if m.addrate_multiplier != nil {
|
||||||
fields = append(fields, usagelog.FieldRateMultiplier)
|
fields = append(fields, usagelog.FieldRateMultiplier)
|
||||||
}
|
}
|
||||||
|
if m.addaccount_rate_multiplier != nil {
|
||||||
|
fields = append(fields, usagelog.FieldAccountRateMultiplier)
|
||||||
|
}
|
||||||
if m.addbilling_type != nil {
|
if m.addbilling_type != nil {
|
||||||
fields = append(fields, usagelog.FieldBillingType)
|
fields = append(fields, usagelog.FieldBillingType)
|
||||||
}
|
}
|
||||||
@@ -12489,6 +12665,8 @@ func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) {
|
|||||||
return m.AddedActualCost()
|
return m.AddedActualCost()
|
||||||
case usagelog.FieldRateMultiplier:
|
case usagelog.FieldRateMultiplier:
|
||||||
return m.AddedRateMultiplier()
|
return m.AddedRateMultiplier()
|
||||||
|
case usagelog.FieldAccountRateMultiplier:
|
||||||
|
return m.AddedAccountRateMultiplier()
|
||||||
case usagelog.FieldBillingType:
|
case usagelog.FieldBillingType:
|
||||||
return m.AddedBillingType()
|
return m.AddedBillingType()
|
||||||
case usagelog.FieldDurationMs:
|
case usagelog.FieldDurationMs:
|
||||||
@@ -12597,6 +12775,13 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.AddRateMultiplier(v)
|
m.AddRateMultiplier(v)
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldAccountRateMultiplier:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.AddAccountRateMultiplier(v)
|
||||||
|
return nil
|
||||||
case usagelog.FieldBillingType:
|
case usagelog.FieldBillingType:
|
||||||
v, ok := value.(int8)
|
v, ok := value.(int8)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -12639,6 +12824,9 @@ func (m *UsageLogMutation) ClearedFields() []string {
|
|||||||
if m.FieldCleared(usagelog.FieldSubscriptionID) {
|
if m.FieldCleared(usagelog.FieldSubscriptionID) {
|
||||||
fields = append(fields, usagelog.FieldSubscriptionID)
|
fields = append(fields, usagelog.FieldSubscriptionID)
|
||||||
}
|
}
|
||||||
|
if m.FieldCleared(usagelog.FieldAccountRateMultiplier) {
|
||||||
|
fields = append(fields, usagelog.FieldAccountRateMultiplier)
|
||||||
|
}
|
||||||
if m.FieldCleared(usagelog.FieldDurationMs) {
|
if m.FieldCleared(usagelog.FieldDurationMs) {
|
||||||
fields = append(fields, usagelog.FieldDurationMs)
|
fields = append(fields, usagelog.FieldDurationMs)
|
||||||
}
|
}
|
||||||
@@ -12674,6 +12862,9 @@ func (m *UsageLogMutation) ClearField(name string) error {
|
|||||||
case usagelog.FieldSubscriptionID:
|
case usagelog.FieldSubscriptionID:
|
||||||
m.ClearSubscriptionID()
|
m.ClearSubscriptionID()
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldAccountRateMultiplier:
|
||||||
|
m.ClearAccountRateMultiplier()
|
||||||
|
return nil
|
||||||
case usagelog.FieldDurationMs:
|
case usagelog.FieldDurationMs:
|
||||||
m.ClearDurationMs()
|
m.ClearDurationMs()
|
||||||
return nil
|
return nil
|
||||||
@@ -12757,6 +12948,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
|||||||
case usagelog.FieldRateMultiplier:
|
case usagelog.FieldRateMultiplier:
|
||||||
m.ResetRateMultiplier()
|
m.ResetRateMultiplier()
|
||||||
return nil
|
return nil
|
||||||
|
case usagelog.FieldAccountRateMultiplier:
|
||||||
|
m.ResetAccountRateMultiplier()
|
||||||
|
return nil
|
||||||
case usagelog.FieldBillingType:
|
case usagelog.FieldBillingType:
|
||||||
m.ResetBillingType()
|
m.ResetBillingType()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -177,22 +177,26 @@ func init() {
|
|||||||
accountDescPriority := accountFields[8].Descriptor()
|
accountDescPriority := accountFields[8].Descriptor()
|
||||||
// account.DefaultPriority holds the default value on creation for the priority field.
|
// account.DefaultPriority holds the default value on creation for the priority field.
|
||||||
account.DefaultPriority = accountDescPriority.Default.(int)
|
account.DefaultPriority = accountDescPriority.Default.(int)
|
||||||
|
// accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||||
|
accountDescRateMultiplier := accountFields[9].Descriptor()
|
||||||
|
// account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||||
|
account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
|
||||||
// accountDescStatus is the schema descriptor for status field.
|
// accountDescStatus is the schema descriptor for status field.
|
||||||
accountDescStatus := accountFields[9].Descriptor()
|
accountDescStatus := accountFields[10].Descriptor()
|
||||||
// account.DefaultStatus holds the default value on creation for the status field.
|
// account.DefaultStatus holds the default value on creation for the status field.
|
||||||
account.DefaultStatus = accountDescStatus.Default.(string)
|
account.DefaultStatus = accountDescStatus.Default.(string)
|
||||||
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||||
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
|
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
|
||||||
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
|
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
|
||||||
accountDescAutoPauseOnExpired := accountFields[13].Descriptor()
|
accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
|
||||||
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
|
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
|
||||||
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
|
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
|
||||||
// accountDescSchedulable is the schema descriptor for schedulable field.
|
// accountDescSchedulable is the schema descriptor for schedulable field.
|
||||||
accountDescSchedulable := accountFields[14].Descriptor()
|
accountDescSchedulable := accountFields[15].Descriptor()
|
||||||
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
|
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
|
||||||
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
|
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
|
||||||
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
|
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
|
||||||
accountDescSessionWindowStatus := accountFields[20].Descriptor()
|
accountDescSessionWindowStatus := accountFields[21].Descriptor()
|
||||||
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
|
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
|
||||||
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
|
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
|
||||||
accountgroupFields := schema.AccountGroup{}.Fields()
|
accountgroupFields := schema.AccountGroup{}.Fields()
|
||||||
@@ -578,31 +582,31 @@ func init() {
|
|||||||
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||||
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
||||||
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
||||||
usagelogDescBillingType := usagelogFields[20].Descriptor()
|
usagelogDescBillingType := usagelogFields[21].Descriptor()
|
||||||
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
||||||
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
||||||
// usagelogDescStream is the schema descriptor for stream field.
|
// usagelogDescStream is the schema descriptor for stream field.
|
||||||
usagelogDescStream := usagelogFields[21].Descriptor()
|
usagelogDescStream := usagelogFields[22].Descriptor()
|
||||||
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
||||||
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
||||||
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
||||||
usagelogDescUserAgent := usagelogFields[24].Descriptor()
|
usagelogDescUserAgent := usagelogFields[25].Descriptor()
|
||||||
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||||
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
||||||
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
||||||
usagelogDescIPAddress := usagelogFields[25].Descriptor()
|
usagelogDescIPAddress := usagelogFields[26].Descriptor()
|
||||||
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||||
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
||||||
// usagelogDescImageCount is the schema descriptor for image_count field.
|
// usagelogDescImageCount is the schema descriptor for image_count field.
|
||||||
usagelogDescImageCount := usagelogFields[26].Descriptor()
|
usagelogDescImageCount := usagelogFields[27].Descriptor()
|
||||||
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
||||||
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
||||||
// usagelogDescImageSize is the schema descriptor for image_size field.
|
// usagelogDescImageSize is the schema descriptor for image_size field.
|
||||||
usagelogDescImageSize := usagelogFields[27].Descriptor()
|
usagelogDescImageSize := usagelogFields[28].Descriptor()
|
||||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||||
usagelogDescCreatedAt := usagelogFields[28].Descriptor()
|
usagelogDescCreatedAt := usagelogFields[29].Descriptor()
|
||||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||||
userMixin := schema.User{}.Mixin()
|
userMixin := schema.User{}.Mixin()
|
||||||
|
|||||||
@@ -102,6 +102,12 @@ func (Account) Fields() []ent.Field {
|
|||||||
field.Int("priority").
|
field.Int("priority").
|
||||||
Default(50),
|
Default(50),
|
||||||
|
|
||||||
|
// rate_multiplier: 账号计费倍率(>=0,允许 0 表示该账号计费为 0)
|
||||||
|
// 仅影响账号维度计费口径,不影响用户/API Key 扣费(分组倍率)
|
||||||
|
field.Float("rate_multiplier").
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
|
||||||
|
Default(1.0),
|
||||||
|
|
||||||
// status: 账户状态,如 "active", "error", "disabled"
|
// status: 账户状态,如 "active", "error", "disabled"
|
||||||
field.String("status").
|
field.String("status").
|
||||||
MaxLen(20).
|
MaxLen(20).
|
||||||
|
|||||||
@@ -85,6 +85,12 @@ func (UsageLog) Fields() []ent.Field {
|
|||||||
Default(1).
|
Default(1).
|
||||||
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
|
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
|
||||||
|
|
||||||
|
// account_rate_multiplier: 账号计费倍率快照(NULL 表示按 1.0 处理)
|
||||||
|
field.Float("account_rate_multiplier").
|
||||||
|
Optional().
|
||||||
|
Nillable().
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
|
||||||
|
|
||||||
// 其他字段
|
// 其他字段
|
||||||
field.Int8("billing_type").
|
field.Int8("billing_type").
|
||||||
Default(0),
|
Default(0),
|
||||||
|
|||||||
@@ -62,6 +62,8 @@ type UsageLog struct {
|
|||||||
ActualCost float64 `json:"actual_cost,omitempty"`
|
ActualCost float64 `json:"actual_cost,omitempty"`
|
||||||
// RateMultiplier holds the value of the "rate_multiplier" field.
|
// RateMultiplier holds the value of the "rate_multiplier" field.
|
||||||
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
|
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
|
||||||
|
// AccountRateMultiplier holds the value of the "account_rate_multiplier" field.
|
||||||
|
AccountRateMultiplier *float64 `json:"account_rate_multiplier,omitempty"`
|
||||||
// BillingType holds the value of the "billing_type" field.
|
// BillingType holds the value of the "billing_type" field.
|
||||||
BillingType int8 `json:"billing_type,omitempty"`
|
BillingType int8 `json:"billing_type,omitempty"`
|
||||||
// Stream holds the value of the "stream" field.
|
// Stream holds the value of the "stream" field.
|
||||||
@@ -165,7 +167,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
|||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
case usagelog.FieldStream:
|
case usagelog.FieldStream:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier:
|
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
@@ -316,6 +318,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.RateMultiplier = value.Float64
|
_m.RateMultiplier = value.Float64
|
||||||
}
|
}
|
||||||
|
case usagelog.FieldAccountRateMultiplier:
|
||||||
|
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field account_rate_multiplier", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.AccountRateMultiplier = new(float64)
|
||||||
|
*_m.AccountRateMultiplier = value.Float64
|
||||||
|
}
|
||||||
case usagelog.FieldBillingType:
|
case usagelog.FieldBillingType:
|
||||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field billing_type", values[i])
|
return fmt.Errorf("unexpected type %T for field billing_type", values[i])
|
||||||
@@ -500,6 +509,11 @@ func (_m *UsageLog) String() string {
|
|||||||
builder.WriteString("rate_multiplier=")
|
builder.WriteString("rate_multiplier=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
|
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
if v := _m.AccountRateMultiplier; v != nil {
|
||||||
|
builder.WriteString("account_rate_multiplier=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
builder.WriteString("billing_type=")
|
builder.WriteString("billing_type=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.BillingType))
|
builder.WriteString(fmt.Sprintf("%v", _m.BillingType))
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ const (
|
|||||||
FieldActualCost = "actual_cost"
|
FieldActualCost = "actual_cost"
|
||||||
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
|
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
|
||||||
FieldRateMultiplier = "rate_multiplier"
|
FieldRateMultiplier = "rate_multiplier"
|
||||||
|
// FieldAccountRateMultiplier holds the string denoting the account_rate_multiplier field in the database.
|
||||||
|
FieldAccountRateMultiplier = "account_rate_multiplier"
|
||||||
// FieldBillingType holds the string denoting the billing_type field in the database.
|
// FieldBillingType holds the string denoting the billing_type field in the database.
|
||||||
FieldBillingType = "billing_type"
|
FieldBillingType = "billing_type"
|
||||||
// FieldStream holds the string denoting the stream field in the database.
|
// FieldStream holds the string denoting the stream field in the database.
|
||||||
@@ -144,6 +146,7 @@ var Columns = []string{
|
|||||||
FieldTotalCost,
|
FieldTotalCost,
|
||||||
FieldActualCost,
|
FieldActualCost,
|
||||||
FieldRateMultiplier,
|
FieldRateMultiplier,
|
||||||
|
FieldAccountRateMultiplier,
|
||||||
FieldBillingType,
|
FieldBillingType,
|
||||||
FieldStream,
|
FieldStream,
|
||||||
FieldDurationMs,
|
FieldDurationMs,
|
||||||
@@ -320,6 +323,11 @@ func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
|
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByAccountRateMultiplier orders the results by the account_rate_multiplier field.
|
||||||
|
func ByAccountRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldAccountRateMultiplier, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByBillingType orders the results by the billing_type field.
|
// ByBillingType orders the results by the billing_type field.
|
||||||
func ByBillingType(opts ...sql.OrderTermOption) OrderOption {
|
func ByBillingType(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldBillingType, opts...).ToFunc()
|
return sql.OrderByField(FieldBillingType, opts...).ToFunc()
|
||||||
|
|||||||
@@ -155,6 +155,11 @@ func RateMultiplier(v float64) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplier applies equality check predicate on the "account_rate_multiplier" field. It's identical to AccountRateMultiplierEQ.
|
||||||
|
func AccountRateMultiplier(v float64) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
// BillingType applies equality check predicate on the "billing_type" field. It's identical to BillingTypeEQ.
|
// BillingType applies equality check predicate on the "billing_type" field. It's identical to BillingTypeEQ.
|
||||||
func BillingType(v int8) predicate.UsageLog {
|
func BillingType(v int8) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
|
||||||
@@ -970,6 +975,56 @@ func RateMultiplierLTE(v float64) predicate.UsageLog {
|
|||||||
return predicate.UsageLog(sql.FieldLTE(FieldRateMultiplier, v))
|
return predicate.UsageLog(sql.FieldLTE(FieldRateMultiplier, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierEQ applies the EQ predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierEQ(v float64) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierNEQ applies the NEQ predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierNEQ(v float64) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNEQ(FieldAccountRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierIn applies the In predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierIn(vs ...float64) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldIn(FieldAccountRateMultiplier, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierNotIn applies the NotIn predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierNotIn(vs ...float64) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNotIn(FieldAccountRateMultiplier, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierGT applies the GT predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierGT(v float64) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldGT(FieldAccountRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierGTE applies the GTE predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierGTE(v float64) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldGTE(FieldAccountRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierLT applies the LT predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierLT(v float64) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldLT(FieldAccountRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierLTE applies the LTE predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierLTE(v float64) predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldLTE(FieldAccountRateMultiplier, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierIsNil applies the IsNil predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierIsNil() predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldIsNull(FieldAccountRateMultiplier))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountRateMultiplierNotNil applies the NotNil predicate on the "account_rate_multiplier" field.
|
||||||
|
func AccountRateMultiplierNotNil() predicate.UsageLog {
|
||||||
|
return predicate.UsageLog(sql.FieldNotNull(FieldAccountRateMultiplier))
|
||||||
|
}
|
||||||
|
|
||||||
// BillingTypeEQ applies the EQ predicate on the "billing_type" field.
|
// BillingTypeEQ applies the EQ predicate on the "billing_type" field.
|
||||||
func BillingTypeEQ(v int8) predicate.UsageLog {
|
func BillingTypeEQ(v int8) predicate.UsageLog {
|
||||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
|
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
|
||||||
|
|||||||
@@ -267,6 +267,20 @@ func (_c *UsageLogCreate) SetNillableRateMultiplier(v *float64) *UsageLogCreate
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||||
|
func (_c *UsageLogCreate) SetAccountRateMultiplier(v float64) *UsageLogCreate {
|
||||||
|
_c.mutation.SetAccountRateMultiplier(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
|
||||||
|
func (_c *UsageLogCreate) SetNillableAccountRateMultiplier(v *float64) *UsageLogCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetAccountRateMultiplier(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetBillingType sets the "billing_type" field.
|
// SetBillingType sets the "billing_type" field.
|
||||||
func (_c *UsageLogCreate) SetBillingType(v int8) *UsageLogCreate {
|
func (_c *UsageLogCreate) SetBillingType(v int8) *UsageLogCreate {
|
||||||
_c.mutation.SetBillingType(v)
|
_c.mutation.SetBillingType(v)
|
||||||
@@ -712,6 +726,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
|
_spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
_node.RateMultiplier = value
|
_node.RateMultiplier = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.AccountRateMultiplier(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
|
||||||
|
_node.AccountRateMultiplier = &value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.BillingType(); ok {
|
if value, ok := _c.mutation.BillingType(); ok {
|
||||||
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
|
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
|
||||||
_node.BillingType = value
|
_node.BillingType = value
|
||||||
@@ -1215,6 +1233,30 @@ func (u *UsageLogUpsert) AddRateMultiplier(v float64) *UsageLogUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||||
|
func (u *UsageLogUpsert) SetAccountRateMultiplier(v float64) *UsageLogUpsert {
|
||||||
|
u.Set(usagelog.FieldAccountRateMultiplier, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsert) UpdateAccountRateMultiplier() *UsageLogUpsert {
|
||||||
|
u.SetExcluded(usagelog.FieldAccountRateMultiplier)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
|
||||||
|
func (u *UsageLogUpsert) AddAccountRateMultiplier(v float64) *UsageLogUpsert {
|
||||||
|
u.Add(usagelog.FieldAccountRateMultiplier, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
|
||||||
|
func (u *UsageLogUpsert) ClearAccountRateMultiplier() *UsageLogUpsert {
|
||||||
|
u.SetNull(usagelog.FieldAccountRateMultiplier)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetBillingType sets the "billing_type" field.
|
// SetBillingType sets the "billing_type" field.
|
||||||
func (u *UsageLogUpsert) SetBillingType(v int8) *UsageLogUpsert {
|
func (u *UsageLogUpsert) SetBillingType(v int8) *UsageLogUpsert {
|
||||||
u.Set(usagelog.FieldBillingType, v)
|
u.Set(usagelog.FieldBillingType, v)
|
||||||
@@ -1795,6 +1837,34 @@ func (u *UsageLogUpsertOne) UpdateRateMultiplier() *UsageLogUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||||
|
func (u *UsageLogUpsertOne) SetAccountRateMultiplier(v float64) *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.SetAccountRateMultiplier(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
|
||||||
|
func (u *UsageLogUpsertOne) AddAccountRateMultiplier(v float64) *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.AddAccountRateMultiplier(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsertOne) UpdateAccountRateMultiplier() *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.UpdateAccountRateMultiplier()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
|
||||||
|
func (u *UsageLogUpsertOne) ClearAccountRateMultiplier() *UsageLogUpsertOne {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.ClearAccountRateMultiplier()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetBillingType sets the "billing_type" field.
|
// SetBillingType sets the "billing_type" field.
|
||||||
func (u *UsageLogUpsertOne) SetBillingType(v int8) *UsageLogUpsertOne {
|
func (u *UsageLogUpsertOne) SetBillingType(v int8) *UsageLogUpsertOne {
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
@@ -2566,6 +2636,34 @@ func (u *UsageLogUpsertBulk) UpdateRateMultiplier() *UsageLogUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||||
|
func (u *UsageLogUpsertBulk) SetAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.SetAccountRateMultiplier(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
|
||||||
|
func (u *UsageLogUpsertBulk) AddAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.AddAccountRateMultiplier(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
|
||||||
|
func (u *UsageLogUpsertBulk) UpdateAccountRateMultiplier() *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.UpdateAccountRateMultiplier()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
|
||||||
|
func (u *UsageLogUpsertBulk) ClearAccountRateMultiplier() *UsageLogUpsertBulk {
|
||||||
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
s.ClearAccountRateMultiplier()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetBillingType sets the "billing_type" field.
|
// SetBillingType sets the "billing_type" field.
|
||||||
func (u *UsageLogUpsertBulk) SetBillingType(v int8) *UsageLogUpsertBulk {
|
func (u *UsageLogUpsertBulk) SetBillingType(v int8) *UsageLogUpsertBulk {
|
||||||
return u.Update(func(s *UsageLogUpsert) {
|
return u.Update(func(s *UsageLogUpsert) {
|
||||||
|
|||||||
@@ -415,6 +415,33 @@ func (_u *UsageLogUpdate) AddRateMultiplier(v float64) *UsageLogUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||||
|
func (_u *UsageLogUpdate) SetAccountRateMultiplier(v float64) *UsageLogUpdate {
|
||||||
|
_u.mutation.ResetAccountRateMultiplier()
|
||||||
|
_u.mutation.SetAccountRateMultiplier(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
|
||||||
|
func (_u *UsageLogUpdate) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetAccountRateMultiplier(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
|
||||||
|
func (_u *UsageLogUpdate) AddAccountRateMultiplier(v float64) *UsageLogUpdate {
|
||||||
|
_u.mutation.AddAccountRateMultiplier(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
|
||||||
|
func (_u *UsageLogUpdate) ClearAccountRateMultiplier() *UsageLogUpdate {
|
||||||
|
_u.mutation.ClearAccountRateMultiplier()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetBillingType sets the "billing_type" field.
|
// SetBillingType sets the "billing_type" field.
|
||||||
func (_u *UsageLogUpdate) SetBillingType(v int8) *UsageLogUpdate {
|
func (_u *UsageLogUpdate) SetBillingType(v int8) *UsageLogUpdate {
|
||||||
_u.mutation.ResetBillingType()
|
_u.mutation.ResetBillingType()
|
||||||
@@ -807,6 +834,15 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
|
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
|
||||||
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
|
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.AccountRateMultiplier(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
|
||||||
|
_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.AccountRateMultiplierCleared() {
|
||||||
|
_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.BillingType(); ok {
|
if value, ok := _u.mutation.BillingType(); ok {
|
||||||
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
|
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
|
||||||
}
|
}
|
||||||
@@ -1406,6 +1442,33 @@ func (_u *UsageLogUpdateOne) AddRateMultiplier(v float64) *UsageLogUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||||
|
func (_u *UsageLogUpdateOne) SetAccountRateMultiplier(v float64) *UsageLogUpdateOne {
|
||||||
|
_u.mutation.ResetAccountRateMultiplier()
|
||||||
|
_u.mutation.SetAccountRateMultiplier(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
|
||||||
|
func (_u *UsageLogUpdateOne) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetAccountRateMultiplier(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
|
||||||
|
func (_u *UsageLogUpdateOne) AddAccountRateMultiplier(v float64) *UsageLogUpdateOne {
|
||||||
|
_u.mutation.AddAccountRateMultiplier(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
|
||||||
|
func (_u *UsageLogUpdateOne) ClearAccountRateMultiplier() *UsageLogUpdateOne {
|
||||||
|
_u.mutation.ClearAccountRateMultiplier()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetBillingType sets the "billing_type" field.
|
// SetBillingType sets the "billing_type" field.
|
||||||
func (_u *UsageLogUpdateOne) SetBillingType(v int8) *UsageLogUpdateOne {
|
func (_u *UsageLogUpdateOne) SetBillingType(v int8) *UsageLogUpdateOne {
|
||||||
_u.mutation.ResetBillingType()
|
_u.mutation.ResetBillingType()
|
||||||
@@ -1828,6 +1891,15 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
|||||||
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
|
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
|
||||||
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
|
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.AccountRateMultiplier(); ok {
|
||||||
|
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
|
||||||
|
_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.AccountRateMultiplierCleared() {
|
||||||
|
_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.BillingType(); ok {
|
if value, ok := _u.mutation.BillingType(); ok {
|
||||||
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
|
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -270,6 +270,29 @@ type GatewaySchedulingConfig struct {
|
|||||||
|
|
||||||
// 过期槽位清理周期(0 表示禁用)
|
// 过期槽位清理周期(0 表示禁用)
|
||||||
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
|
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
|
||||||
|
|
||||||
|
// 受控回源配置
|
||||||
|
DbFallbackEnabled bool `mapstructure:"db_fallback_enabled"`
|
||||||
|
// 受控回源超时(秒),0 表示不额外收紧超时
|
||||||
|
DbFallbackTimeoutSeconds int `mapstructure:"db_fallback_timeout_seconds"`
|
||||||
|
// 受控回源限流(实例级 QPS),0 表示不限制
|
||||||
|
DbFallbackMaxQPS int `mapstructure:"db_fallback_max_qps"`
|
||||||
|
|
||||||
|
// Outbox 轮询与滞后阈值配置
|
||||||
|
// Outbox 轮询周期(秒)
|
||||||
|
OutboxPollIntervalSeconds int `mapstructure:"outbox_poll_interval_seconds"`
|
||||||
|
// Outbox 滞后告警阈值(秒)
|
||||||
|
OutboxLagWarnSeconds int `mapstructure:"outbox_lag_warn_seconds"`
|
||||||
|
// Outbox 触发强制重建阈值(秒)
|
||||||
|
OutboxLagRebuildSeconds int `mapstructure:"outbox_lag_rebuild_seconds"`
|
||||||
|
// Outbox 连续滞后触发次数
|
||||||
|
OutboxLagRebuildFailures int `mapstructure:"outbox_lag_rebuild_failures"`
|
||||||
|
// Outbox 积压触发重建阈值(行数)
|
||||||
|
OutboxBacklogRebuildRows int `mapstructure:"outbox_backlog_rebuild_rows"`
|
||||||
|
|
||||||
|
// 全量重建周期配置
|
||||||
|
// 全量重建周期(秒),0 表示禁用
|
||||||
|
FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerConfig) Address() string {
|
func (s *ServerConfig) Address() string {
|
||||||
@@ -744,11 +767,20 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||||
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
|
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
|
||||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
|
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||||
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
||||||
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
||||||
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
||||||
|
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
|
||||||
|
viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0)
|
||||||
|
viper.SetDefault("gateway.scheduling.db_fallback_max_qps", 0)
|
||||||
|
viper.SetDefault("gateway.scheduling.outbox_poll_interval_seconds", 1)
|
||||||
|
viper.SetDefault("gateway.scheduling.outbox_lag_warn_seconds", 5)
|
||||||
|
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_seconds", 10)
|
||||||
|
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
|
||||||
|
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
|
||||||
|
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
|
||||||
viper.SetDefault("concurrency.ping_interval", 10)
|
viper.SetDefault("concurrency.ping_interval", 10)
|
||||||
|
|
||||||
// TokenRefresh
|
// TokenRefresh
|
||||||
@@ -1021,6 +1053,35 @@ func (c *Config) Validate() error {
|
|||||||
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
|
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
|
||||||
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
|
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
|
||||||
}
|
}
|
||||||
|
if c.Gateway.Scheduling.DbFallbackTimeoutSeconds < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.db_fallback_timeout_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.DbFallbackMaxQPS < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.db_fallback_max_qps must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.OutboxPollIntervalSeconds <= 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.outbox_poll_interval_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.OutboxLagWarnSeconds < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.outbox_lag_warn_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.OutboxLagRebuildSeconds < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.OutboxLagRebuildFailures <= 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_failures must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.OutboxBacklogRebuildRows < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.outbox_backlog_rebuild_rows must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.FullRebuildIntervalSeconds < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.full_rebuild_interval_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.OutboxLagWarnSeconds > 0 &&
|
||||||
|
c.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 &&
|
||||||
|
c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds {
|
||||||
|
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds")
|
||||||
|
}
|
||||||
if c.Ops.MetricsCollectorCache.TTL < 0 {
|
if c.Ops.MetricsCollectorCache.TTL < 0 {
|
||||||
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
|
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,8 +39,8 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
|||||||
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
|
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
|
||||||
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
|
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
|
||||||
}
|
}
|
||||||
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second {
|
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 120*time.Second {
|
||||||
t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
|
t.Fatalf("StickySessionWaitTimeout = %v, want 120s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
|
||||||
}
|
}
|
||||||
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
|
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
|
||||||
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
|
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ type CreateAccountRequest struct {
|
|||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
Priority int `json:"priority"`
|
Priority int `json:"priority"`
|
||||||
|
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||||
GroupIDs []int64 `json:"group_ids"`
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
ExpiresAt *int64 `json:"expires_at"`
|
ExpiresAt *int64 `json:"expires_at"`
|
||||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||||
@@ -101,6 +102,7 @@ type UpdateAccountRequest struct {
|
|||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Concurrency *int `json:"concurrency"`
|
Concurrency *int `json:"concurrency"`
|
||||||
Priority *int `json:"priority"`
|
Priority *int `json:"priority"`
|
||||||
|
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||||
GroupIDs *[]int64 `json:"group_ids"`
|
GroupIDs *[]int64 `json:"group_ids"`
|
||||||
ExpiresAt *int64 `json:"expires_at"`
|
ExpiresAt *int64 `json:"expires_at"`
|
||||||
@@ -115,6 +117,7 @@ type BulkUpdateAccountsRequest struct {
|
|||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Concurrency *int `json:"concurrency"`
|
Concurrency *int `json:"concurrency"`
|
||||||
Priority *int `json:"priority"`
|
Priority *int `json:"priority"`
|
||||||
|
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||||
Schedulable *bool `json:"schedulable"`
|
Schedulable *bool `json:"schedulable"`
|
||||||
GroupIDs *[]int64 `json:"group_ids"`
|
GroupIDs *[]int64 `json:"group_ids"`
|
||||||
@@ -199,6 +202,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
|
||||||
|
response.BadRequest(c, "rate_multiplier must be >= 0")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 确定是否跳过混合渠道检查
|
// 确定是否跳过混合渠道检查
|
||||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||||
@@ -213,6 +220,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
|||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
Priority: req.Priority,
|
Priority: req.Priority,
|
||||||
|
RateMultiplier: req.RateMultiplier,
|
||||||
GroupIDs: req.GroupIDs,
|
GroupIDs: req.GroupIDs,
|
||||||
ExpiresAt: req.ExpiresAt,
|
ExpiresAt: req.ExpiresAt,
|
||||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||||
@@ -258,6 +266,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
|
||||||
|
response.BadRequest(c, "rate_multiplier must be >= 0")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 确定是否跳过混合渠道检查
|
// 确定是否跳过混合渠道检查
|
||||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||||
@@ -271,6 +283,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
|||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
Concurrency: req.Concurrency, // 指针类型,nil 表示未提供
|
Concurrency: req.Concurrency, // 指针类型,nil 表示未提供
|
||||||
Priority: req.Priority, // 指针类型,nil 表示未提供
|
Priority: req.Priority, // 指针类型,nil 表示未提供
|
||||||
|
RateMultiplier: req.RateMultiplier,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
GroupIDs: req.GroupIDs,
|
GroupIDs: req.GroupIDs,
|
||||||
ExpiresAt: req.ExpiresAt,
|
ExpiresAt: req.ExpiresAt,
|
||||||
@@ -652,6 +665,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
|
||||||
|
response.BadRequest(c, "rate_multiplier must be >= 0")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 确定是否跳过混合渠道检查
|
// 确定是否跳过混合渠道检查
|
||||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||||
@@ -660,6 +677,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
|||||||
req.ProxyID != nil ||
|
req.ProxyID != nil ||
|
||||||
req.Concurrency != nil ||
|
req.Concurrency != nil ||
|
||||||
req.Priority != nil ||
|
req.Priority != nil ||
|
||||||
|
req.RateMultiplier != nil ||
|
||||||
req.Status != "" ||
|
req.Status != "" ||
|
||||||
req.Schedulable != nil ||
|
req.Schedulable != nil ||
|
||||||
req.GroupIDs != nil ||
|
req.GroupIDs != nil ||
|
||||||
@@ -677,6 +695,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
|||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
Priority: req.Priority,
|
Priority: req.Priority,
|
||||||
|
RateMultiplier: req.RateMultiplier,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
Schedulable: req.Schedulable,
|
Schedulable: req.Schedulable,
|
||||||
GroupIDs: req.GroupIDs,
|
GroupIDs: req.GroupIDs,
|
||||||
|
|||||||
@@ -186,13 +186,16 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
|||||||
|
|
||||||
// GetUsageTrend handles getting usage trend data
|
// GetUsageTrend handles getting usage trend data
|
||||||
// GET /api/v1/admin/dashboard/trend
|
// GET /api/v1/admin/dashboard/trend
|
||||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
|
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream
|
||||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||||
startTime, endTime := parseTimeRange(c)
|
startTime, endTime := parseTimeRange(c)
|
||||||
granularity := c.DefaultQuery("granularity", "day")
|
granularity := c.DefaultQuery("granularity", "day")
|
||||||
|
|
||||||
// Parse optional filter params
|
// Parse optional filter params
|
||||||
var userID, apiKeyID int64
|
var userID, apiKeyID, accountID, groupID int64
|
||||||
|
var model string
|
||||||
|
var stream *bool
|
||||||
|
|
||||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||||
userID = id
|
userID = id
|
||||||
@@ -203,8 +206,26 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
|||||||
apiKeyID = id
|
apiKeyID = id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
||||||
|
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
|
||||||
|
accountID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||||
|
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
|
||||||
|
groupID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if modelStr := c.Query("model"); modelStr != "" {
|
||||||
|
model = modelStr
|
||||||
|
}
|
||||||
|
if streamStr := c.Query("stream"); streamStr != "" {
|
||||||
|
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
||||||
|
stream = &streamVal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
|
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get usage trend")
|
response.Error(c, 500, "Failed to get usage trend")
|
||||||
return
|
return
|
||||||
@@ -220,12 +241,14 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
|||||||
|
|
||||||
// GetModelStats handles getting model usage statistics
|
// GetModelStats handles getting model usage statistics
|
||||||
// GET /api/v1/admin/dashboard/models
|
// GET /api/v1/admin/dashboard/models
|
||||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
|
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream
|
||||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||||
startTime, endTime := parseTimeRange(c)
|
startTime, endTime := parseTimeRange(c)
|
||||||
|
|
||||||
// Parse optional filter params
|
// Parse optional filter params
|
||||||
var userID, apiKeyID int64
|
var userID, apiKeyID, accountID, groupID int64
|
||||||
|
var stream *bool
|
||||||
|
|
||||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||||
userID = id
|
userID = id
|
||||||
@@ -236,8 +259,23 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
apiKeyID = id
|
apiKeyID = id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
||||||
|
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
|
||||||
|
accountID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||||
|
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
|
||||||
|
groupID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if streamStr := c.Query("stream"); streamStr != "" {
|
||||||
|
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
||||||
|
stream = &streamVal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
|
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get model statistics")
|
response.Error(c, 500, "Failed to get model statistics")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -7,8 +7,10 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gin-gonic/gin/binding"
|
"github.com/gin-gonic/gin/binding"
|
||||||
@@ -18,8 +20,6 @@ var validOpsAlertMetricTypes = []string{
|
|||||||
"success_rate",
|
"success_rate",
|
||||||
"error_rate",
|
"error_rate",
|
||||||
"upstream_error_rate",
|
"upstream_error_rate",
|
||||||
"p95_latency_ms",
|
|
||||||
"p99_latency_ms",
|
|
||||||
"cpu_usage_percent",
|
"cpu_usage_percent",
|
||||||
"memory_usage_percent",
|
"memory_usage_percent",
|
||||||
"concurrency_queue_depth",
|
"concurrency_queue_depth",
|
||||||
@@ -372,8 +372,135 @@ func (h *OpsHandler) DeleteAlertRule(c *gin.Context) {
|
|||||||
response.Success(c, gin.H{"deleted": true})
|
response.Success(c, gin.H{"deleted": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAlertEvent returns a single ops alert event.
|
||||||
|
// GET /api/v1/admin/ops/alert-events/:id
|
||||||
|
func (h *OpsHandler) GetAlertEvent(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid event ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ev, err := h.opsService.GetAlertEventByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, ev)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAlertEventStatus updates an ops alert event status.
|
||||||
|
// PUT /api/v1/admin/ops/alert-events/:id/status
|
||||||
|
func (h *OpsHandler) UpdateAlertEventStatus(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid event ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payload.Status = strings.TrimSpace(payload.Status)
|
||||||
|
if payload.Status == "" {
|
||||||
|
response.BadRequest(c, "Invalid status")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if payload.Status != service.OpsAlertStatusResolved && payload.Status != service.OpsAlertStatusManualResolved {
|
||||||
|
response.BadRequest(c, "Invalid status")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var resolvedAt *time.Time
|
||||||
|
if payload.Status == service.OpsAlertStatusResolved || payload.Status == service.OpsAlertStatusManualResolved {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
resolvedAt = &now
|
||||||
|
}
|
||||||
|
if err := h.opsService.UpdateAlertEventStatus(c.Request.Context(), id, payload.Status, resolvedAt); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"updated": true})
|
||||||
|
}
|
||||||
|
|
||||||
// ListAlertEvents lists recent ops alert events.
|
// ListAlertEvents lists recent ops alert events.
|
||||||
// GET /api/v1/admin/ops/alert-events
|
// GET /api/v1/admin/ops/alert-events
|
||||||
|
// CreateAlertSilence creates a scoped silence for ops alerts.
|
||||||
|
// POST /api/v1/admin/ops/alert-silences
|
||||||
|
func (h *OpsHandler) CreateAlertSilence(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
RuleID int64 `json:"rule_id"`
|
||||||
|
Platform string `json:"platform"`
|
||||||
|
GroupID *int64 `json:"group_id"`
|
||||||
|
Region *string `json:"region"`
|
||||||
|
Until string `json:"until"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
until, err := time.Parse(time.RFC3339, strings.TrimSpace(payload.Until))
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid until")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
createdBy := (*int64)(nil)
|
||||||
|
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
|
||||||
|
uid := subject.UserID
|
||||||
|
createdBy = &uid
|
||||||
|
}
|
||||||
|
|
||||||
|
silence := &service.OpsAlertSilence{
|
||||||
|
RuleID: payload.RuleID,
|
||||||
|
Platform: strings.TrimSpace(payload.Platform),
|
||||||
|
GroupID: payload.GroupID,
|
||||||
|
Region: payload.Region,
|
||||||
|
Until: until,
|
||||||
|
Reason: strings.TrimSpace(payload.Reason),
|
||||||
|
CreatedBy: createdBy,
|
||||||
|
}
|
||||||
|
|
||||||
|
created, err := h.opsService.CreateAlertSilence(c.Request.Context(), silence)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, created)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
|
func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
|
||||||
if h.opsService == nil {
|
if h.opsService == nil {
|
||||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
@@ -384,7 +511,7 @@ func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
limit := 100
|
limit := 20
|
||||||
if raw := strings.TrimSpace(c.Query("limit")); raw != "" {
|
if raw := strings.TrimSpace(c.Query("limit")); raw != "" {
|
||||||
n, err := strconv.Atoi(raw)
|
n, err := strconv.Atoi(raw)
|
||||||
if err != nil || n <= 0 {
|
if err != nil || n <= 0 {
|
||||||
@@ -400,6 +527,49 @@ func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
|
|||||||
Severity: strings.TrimSpace(c.Query("severity")),
|
Severity: strings.TrimSpace(c.Query("severity")),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v := strings.TrimSpace(c.Query("email_sent")); v != "" {
|
||||||
|
vv := strings.ToLower(v)
|
||||||
|
switch vv {
|
||||||
|
case "true", "1":
|
||||||
|
b := true
|
||||||
|
filter.EmailSent = &b
|
||||||
|
case "false", "0":
|
||||||
|
b := false
|
||||||
|
filter.EmailSent = &b
|
||||||
|
default:
|
||||||
|
response.BadRequest(c, "Invalid email_sent")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cursor pagination: both params must be provided together.
|
||||||
|
rawTS := strings.TrimSpace(c.Query("before_fired_at"))
|
||||||
|
rawID := strings.TrimSpace(c.Query("before_id"))
|
||||||
|
if (rawTS == "") != (rawID == "") {
|
||||||
|
response.BadRequest(c, "before_fired_at and before_id must be provided together")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rawTS != "" {
|
||||||
|
ts, err := time.Parse(time.RFC3339Nano, rawTS)
|
||||||
|
if err != nil {
|
||||||
|
if t2, err2 := time.Parse(time.RFC3339, rawTS); err2 == nil {
|
||||||
|
ts = t2
|
||||||
|
} else {
|
||||||
|
response.BadRequest(c, "Invalid before_fired_at")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filter.BeforeFiredAt = &ts
|
||||||
|
}
|
||||||
|
if rawID != "" {
|
||||||
|
id, err := strconv.ParseInt(rawID, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid before_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter.BeforeID = &id
|
||||||
|
}
|
||||||
|
|
||||||
// Optional global filter support (platform/group/time range).
|
// Optional global filter support (platform/group/time range).
|
||||||
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
||||||
filter.Platform = platform
|
filter.Platform = platform
|
||||||
|
|||||||
@@ -19,6 +19,57 @@ type OpsHandler struct {
|
|||||||
opsService *service.OpsService
|
opsService *service.OpsService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetErrorLogByID returns ops error log detail.
|
||||||
|
// GET /api/v1/admin/ops/errors/:id
|
||||||
|
func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idStr := strings.TrimSpace(c.Param("id"))
|
||||||
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid error id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
opsListViewErrors = "errors"
|
||||||
|
opsListViewExcluded = "excluded"
|
||||||
|
opsListViewAll = "all"
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseOpsViewParam(c *gin.Context) string {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
v := strings.ToLower(strings.TrimSpace(c.Query("view")))
|
||||||
|
switch v {
|
||||||
|
case "", opsListViewErrors:
|
||||||
|
return opsListViewErrors
|
||||||
|
case opsListViewExcluded:
|
||||||
|
return opsListViewExcluded
|
||||||
|
case opsListViewAll:
|
||||||
|
return opsListViewAll
|
||||||
|
default:
|
||||||
|
return opsListViewErrors
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
|
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
|
||||||
return &OpsHandler{opsService: opsService}
|
return &OpsHandler{opsService: opsService}
|
||||||
}
|
}
|
||||||
@@ -47,16 +98,26 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
filter := &service.OpsErrorLogFilter{
|
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
|
||||||
Page: page,
|
|
||||||
PageSize: pageSize,
|
|
||||||
}
|
|
||||||
if !startTime.IsZero() {
|
if !startTime.IsZero() {
|
||||||
filter.StartTime = &startTime
|
filter.StartTime = &startTime
|
||||||
}
|
}
|
||||||
if !endTime.IsZero() {
|
if !endTime.IsZero() {
|
||||||
filter.EndTime = &endTime
|
filter.EndTime = &endTime
|
||||||
}
|
}
|
||||||
|
filter.View = parseOpsViewParam(c)
|
||||||
|
filter.Phase = strings.TrimSpace(c.Query("phase"))
|
||||||
|
filter.Owner = strings.TrimSpace(c.Query("error_owner"))
|
||||||
|
filter.Source = strings.TrimSpace(c.Query("error_source"))
|
||||||
|
filter.Query = strings.TrimSpace(c.Query("q"))
|
||||||
|
filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
|
||||||
|
|
||||||
|
// Force request errors: client-visible status >= 400.
|
||||||
|
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
|
||||||
|
if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
|
||||||
|
filter.Phase = ""
|
||||||
|
}
|
||||||
|
|
||||||
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
||||||
filter.Platform = platform
|
filter.Platform = platform
|
||||||
@@ -77,11 +138,19 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
filter.AccountID = &id
|
filter.AccountID = &id
|
||||||
}
|
}
|
||||||
if phase := strings.TrimSpace(c.Query("phase")); phase != "" {
|
|
||||||
filter.Phase = phase
|
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
|
||||||
}
|
switch strings.ToLower(v) {
|
||||||
if q := strings.TrimSpace(c.Query("q")); q != "" {
|
case "1", "true", "yes":
|
||||||
filter.Query = q
|
b := true
|
||||||
|
filter.Resolved = &b
|
||||||
|
case "0", "false", "no":
|
||||||
|
b := false
|
||||||
|
filter.Resolved = &b
|
||||||
|
default:
|
||||||
|
response.BadRequest(c, "Invalid resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
|
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
|
||||||
parts := strings.Split(statusCodesStr, ",")
|
parts := strings.Split(statusCodesStr, ",")
|
||||||
@@ -106,13 +175,120 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
|
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetErrorLogByID returns a single error log detail.
|
// ListRequestErrors lists client-visible request errors.
|
||||||
// GET /api/v1/admin/ops/errors/:id
|
// GET /api/v1/admin/ops/request-errors
|
||||||
func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
|
func (h *OpsHandler) ListRequestErrors(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
page, pageSize := response.ParsePagination(c)
|
||||||
|
if pageSize > 500 {
|
||||||
|
pageSize = 500
|
||||||
|
}
|
||||||
|
startTime, endTime, err := parseOpsTimeRange(c, "1h")
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
|
||||||
|
if !startTime.IsZero() {
|
||||||
|
filter.StartTime = &startTime
|
||||||
|
}
|
||||||
|
if !endTime.IsZero() {
|
||||||
|
filter.EndTime = &endTime
|
||||||
|
}
|
||||||
|
filter.View = parseOpsViewParam(c)
|
||||||
|
filter.Phase = strings.TrimSpace(c.Query("phase"))
|
||||||
|
filter.Owner = strings.TrimSpace(c.Query("error_owner"))
|
||||||
|
filter.Source = strings.TrimSpace(c.Query("error_source"))
|
||||||
|
filter.Query = strings.TrimSpace(c.Query("q"))
|
||||||
|
filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
|
||||||
|
|
||||||
|
// Force request errors: client-visible status >= 400.
|
||||||
|
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
|
||||||
|
if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
|
||||||
|
filter.Phase = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
||||||
|
filter.Platform = platform
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||||
|
id, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid group_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter.GroupID = &id
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
|
||||||
|
id, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid account_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter.AccountID = &id
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
|
||||||
|
switch strings.ToLower(v) {
|
||||||
|
case "1", "true", "yes":
|
||||||
|
b := true
|
||||||
|
filter.Resolved = &b
|
||||||
|
case "0", "false", "no":
|
||||||
|
b := false
|
||||||
|
filter.Resolved = &b
|
||||||
|
default:
|
||||||
|
response.BadRequest(c, "Invalid resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
|
||||||
|
parts := strings.Split(statusCodesStr, ",")
|
||||||
|
out := make([]int, 0, len(parts))
|
||||||
|
for _, part := range parts {
|
||||||
|
p := strings.TrimSpace(part)
|
||||||
|
if p == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n, err := strconv.Atoi(p)
|
||||||
|
if err != nil || n < 0 {
|
||||||
|
response.BadRequest(c, "Invalid status_codes")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out = append(out, n)
|
||||||
|
}
|
||||||
|
filter.StatusCodes = out
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRequestError returns request error detail.
|
||||||
|
// GET /api/v1/admin/ops/request-errors/:id
|
||||||
|
func (h *OpsHandler) GetRequestError(c *gin.Context) {
|
||||||
|
// same storage; just proxy to existing detail
|
||||||
|
h.GetErrorLogByID(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRequestErrorUpstreamErrors lists upstream error logs correlated to a request error.
|
||||||
|
// GET /api/v1/admin/ops/request-errors/:id/upstream-errors
|
||||||
|
func (h *OpsHandler) ListRequestErrorUpstreamErrors(c *gin.Context) {
|
||||||
if h.opsService == nil {
|
if h.opsService == nil {
|
||||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
return
|
return
|
||||||
@@ -129,15 +305,306 @@ func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load request error to get correlation keys.
|
||||||
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
|
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, detail)
|
// Correlate by request_id/client_request_id.
|
||||||
|
requestID := strings.TrimSpace(detail.RequestID)
|
||||||
|
clientRequestID := strings.TrimSpace(detail.ClientRequestID)
|
||||||
|
if requestID == "" && clientRequestID == "" {
|
||||||
|
response.Paginated(c, []*service.OpsErrorLog{}, 0, 1, 10)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
page, pageSize := response.ParsePagination(c)
|
||||||
|
if pageSize > 500 {
|
||||||
|
pageSize = 500
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep correlation window wide enough so linked upstream errors
|
||||||
|
// are discoverable even when UI defaults to 1h elsewhere.
|
||||||
|
startTime, endTime, err := parseOpsTimeRange(c, "30d")
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
|
||||||
|
if !startTime.IsZero() {
|
||||||
|
filter.StartTime = &startTime
|
||||||
|
}
|
||||||
|
if !endTime.IsZero() {
|
||||||
|
filter.EndTime = &endTime
|
||||||
|
}
|
||||||
|
filter.View = "all"
|
||||||
|
filter.Phase = "upstream"
|
||||||
|
filter.Owner = "provider"
|
||||||
|
filter.Source = strings.TrimSpace(c.Query("error_source"))
|
||||||
|
filter.Query = strings.TrimSpace(c.Query("q"))
|
||||||
|
|
||||||
|
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
||||||
|
filter.Platform = platform
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefer exact match on request_id; if missing, fall back to client_request_id.
|
||||||
|
if requestID != "" {
|
||||||
|
filter.RequestID = requestID
|
||||||
|
} else {
|
||||||
|
filter.ClientRequestID = clientRequestID
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If client asks for details, expand each upstream error log to include upstream response fields.
|
||||||
|
includeDetail := strings.TrimSpace(c.Query("include_detail"))
|
||||||
|
if includeDetail == "1" || strings.EqualFold(includeDetail, "true") || strings.EqualFold(includeDetail, "yes") {
|
||||||
|
details := make([]*service.OpsErrorLogDetail, 0, len(result.Errors))
|
||||||
|
for _, item := range result.Errors {
|
||||||
|
if item == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
d, err := h.opsService.GetErrorLogByID(c.Request.Context(), item.ID)
|
||||||
|
if err != nil || d == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
details = append(details, d)
|
||||||
|
}
|
||||||
|
response.Paginated(c, details, int64(result.Total), result.Page, result.PageSize)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RetryRequestErrorClient retries the client request based on stored request body.
|
||||||
|
// POST /api/v1/admin/ops/request-errors/:id/retry-client
|
||||||
|
func (h *OpsHandler) RetryRequestErrorClient(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok || subject.UserID <= 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idStr := strings.TrimSpace(c.Param("id"))
|
||||||
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid error id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeClient, nil)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryRequestErrorUpstreamEvent retries a specific upstream attempt using captured upstream_request_body.
|
||||||
|
// POST /api/v1/admin/ops/request-errors/:id/upstream-errors/:idx/retry
|
||||||
|
func (h *OpsHandler) RetryRequestErrorUpstreamEvent(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok || subject.UserID <= 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idStr := strings.TrimSpace(c.Param("id"))
|
||||||
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid error id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idxStr := strings.TrimSpace(c.Param("idx"))
|
||||||
|
idx, err := strconv.Atoi(idxStr)
|
||||||
|
if err != nil || idx < 0 {
|
||||||
|
response.BadRequest(c, "Invalid upstream idx")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.opsService.RetryUpstreamEvent(c.Request.Context(), subject.UserID, id, idx)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveRequestError toggles resolved status.
|
||||||
|
// PUT /api/v1/admin/ops/request-errors/:id/resolve
|
||||||
|
func (h *OpsHandler) ResolveRequestError(c *gin.Context) {
|
||||||
|
h.UpdateErrorResolution(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListUpstreamErrors lists independent upstream errors.
|
||||||
|
// GET /api/v1/admin/ops/upstream-errors
|
||||||
|
func (h *OpsHandler) ListUpstreamErrors(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
page, pageSize := response.ParsePagination(c)
|
||||||
|
if pageSize > 500 {
|
||||||
|
pageSize = 500
|
||||||
|
}
|
||||||
|
startTime, endTime, err := parseOpsTimeRange(c, "1h")
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
|
||||||
|
if !startTime.IsZero() {
|
||||||
|
filter.StartTime = &startTime
|
||||||
|
}
|
||||||
|
if !endTime.IsZero() {
|
||||||
|
filter.EndTime = &endTime
|
||||||
|
}
|
||||||
|
|
||||||
|
filter.View = parseOpsViewParam(c)
|
||||||
|
filter.Phase = "upstream"
|
||||||
|
filter.Owner = "provider"
|
||||||
|
filter.Source = strings.TrimSpace(c.Query("error_source"))
|
||||||
|
filter.Query = strings.TrimSpace(c.Query("q"))
|
||||||
|
|
||||||
|
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
||||||
|
filter.Platform = platform
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||||
|
id, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid group_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter.GroupID = &id
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
|
||||||
|
id, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid account_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter.AccountID = &id
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
|
||||||
|
switch strings.ToLower(v) {
|
||||||
|
case "1", "true", "yes":
|
||||||
|
b := true
|
||||||
|
filter.Resolved = &b
|
||||||
|
case "0", "false", "no":
|
||||||
|
b := false
|
||||||
|
filter.Resolved = &b
|
||||||
|
default:
|
||||||
|
response.BadRequest(c, "Invalid resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
|
||||||
|
parts := strings.Split(statusCodesStr, ",")
|
||||||
|
out := make([]int, 0, len(parts))
|
||||||
|
for _, part := range parts {
|
||||||
|
p := strings.TrimSpace(part)
|
||||||
|
if p == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n, err := strconv.Atoi(p)
|
||||||
|
if err != nil || n < 0 {
|
||||||
|
response.BadRequest(c, "Invalid status_codes")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out = append(out, n)
|
||||||
|
}
|
||||||
|
filter.StatusCodes = out
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUpstreamError returns upstream error detail.
|
||||||
|
// GET /api/v1/admin/ops/upstream-errors/:id
|
||||||
|
func (h *OpsHandler) GetUpstreamError(c *gin.Context) {
|
||||||
|
h.GetErrorLogByID(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryUpstreamError retries upstream error using the original account_id.
|
||||||
|
// POST /api/v1/admin/ops/upstream-errors/:id/retry
|
||||||
|
func (h *OpsHandler) RetryUpstreamError(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok || subject.UserID <= 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idStr := strings.TrimSpace(c.Param("id"))
|
||||||
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid error id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeUpstream, nil)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveUpstreamError toggles resolved status.
|
||||||
|
// PUT /api/v1/admin/ops/upstream-errors/:id/resolve
|
||||||
|
func (h *OpsHandler) ResolveUpstreamError(c *gin.Context) {
|
||||||
|
h.UpdateErrorResolution(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Existing endpoints ====================
|
||||||
|
|
||||||
// ListRequestDetails returns a request-level list (success + error) for drill-down.
|
// ListRequestDetails returns a request-level list (success + error) for drill-down.
|
||||||
// GET /api/v1/admin/ops/requests
|
// GET /api/v1/admin/ops/requests
|
||||||
func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
|
func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
|
||||||
@@ -242,6 +709,11 @@ func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
|
|||||||
type opsRetryRequest struct {
|
type opsRetryRequest struct {
|
||||||
Mode string `json:"mode"`
|
Mode string `json:"mode"`
|
||||||
PinnedAccountID *int64 `json:"pinned_account_id"`
|
PinnedAccountID *int64 `json:"pinned_account_id"`
|
||||||
|
Force bool `json:"force"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type opsResolveRequest struct {
|
||||||
|
Resolved bool `json:"resolved"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RetryErrorRequest retries a failed request using stored request_body.
|
// RetryErrorRequest retries a failed request using stored request_body.
|
||||||
@@ -278,6 +750,16 @@ func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
|
|||||||
req.Mode = service.OpsRetryModeClient
|
req.Mode = service.OpsRetryModeClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Force flag is currently a UI-level acknowledgement. Server may still enforce safety constraints.
|
||||||
|
_ = req.Force
|
||||||
|
|
||||||
|
// Legacy endpoint safety: only allow retrying the client request here.
|
||||||
|
// Upstream retries must go through the split endpoints.
|
||||||
|
if strings.EqualFold(strings.TrimSpace(req.Mode), service.OpsRetryModeUpstream) {
|
||||||
|
response.BadRequest(c, "upstream retry is not supported on this endpoint")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
|
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@@ -287,6 +769,81 @@ func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
|
|||||||
response.Success(c, result)
|
response.Success(c, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListRetryAttempts lists retry attempts for an error log.
|
||||||
|
// GET /api/v1/admin/ops/errors/:id/retries
|
||||||
|
func (h *OpsHandler) ListRetryAttempts(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idStr := strings.TrimSpace(c.Param("id"))
|
||||||
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid error id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := 50
|
||||||
|
if v := strings.TrimSpace(c.Query("limit")); v != "" {
|
||||||
|
n, err := strconv.Atoi(v)
|
||||||
|
if err != nil || n <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid limit")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
limit = n
|
||||||
|
}
|
||||||
|
|
||||||
|
items, err := h.opsService.ListRetryAttemptsByErrorID(c.Request.Context(), id, limit)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, items)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateErrorResolution allows manual resolve/unresolve.
|
||||||
|
// PUT /api/v1/admin/ops/errors/:id/resolve
|
||||||
|
func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok || subject.UserID <= 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idStr := strings.TrimSpace(c.Param("id"))
|
||||||
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid error id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req opsResolveRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
uid := subject.UserID
|
||||||
|
if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid, nil); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"ok": true})
|
||||||
|
}
|
||||||
|
|
||||||
func parseOpsTimeRange(c *gin.Context, defaultRange string) (time.Time, time.Time, error) {
|
func parseOpsTimeRange(c *gin.Context, defaultRange string) (time.Time, time.Time, error) {
|
||||||
startStr := strings.TrimSpace(c.Query("start_time"))
|
startStr := strings.TrimSpace(c.Query("start_time"))
|
||||||
endStr := strings.TrimSpace(c.Query("end_time"))
|
endStr := strings.TrimSpace(c.Query("end_time"))
|
||||||
@@ -358,6 +915,10 @@ func parseOpsDuration(v string) (time.Duration, bool) {
|
|||||||
return 6 * time.Hour, true
|
return 6 * time.Hour, true
|
||||||
case "24h":
|
case "24h":
|
||||||
return 24 * time.Hour, true
|
return 24 * time.Hour, true
|
||||||
|
case "7d":
|
||||||
|
return 7 * 24 * time.Hour, true
|
||||||
|
case "30d":
|
||||||
|
return 30 * 24 * time.Hour, true
|
||||||
default:
|
default:
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -118,3 +118,96 @@ func (h *OpsHandler) GetAccountAvailability(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
response.Success(c, payload)
|
response.Success(c, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseOpsRealtimeWindow(v string) (time.Duration, string, bool) {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||||
|
case "", "1min", "1m":
|
||||||
|
return 1 * time.Minute, "1min", true
|
||||||
|
case "5min", "5m":
|
||||||
|
return 5 * time.Minute, "5min", true
|
||||||
|
case "30min", "30m":
|
||||||
|
return 30 * time.Minute, "30min", true
|
||||||
|
case "1h", "60m", "60min":
|
||||||
|
return 1 * time.Hour, "1h", true
|
||||||
|
default:
|
||||||
|
return 0, "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRealtimeTrafficSummary returns QPS/TPS current/peak/avg for the selected window.
|
||||||
|
// GET /api/v1/admin/ops/realtime-traffic
|
||||||
|
//
|
||||||
|
// Query params:
|
||||||
|
// - window: 1min|5min|30min|1h (default: 1min)
|
||||||
|
// - platform: optional
|
||||||
|
// - group_id: optional
|
||||||
|
func (h *OpsHandler) GetRealtimeTrafficSummary(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
windowDur, windowLabel, ok := parseOpsRealtimeWindow(c.Query("window"))
|
||||||
|
if !ok {
|
||||||
|
response.BadRequest(c, "Invalid window")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
platform := strings.TrimSpace(c.Query("platform"))
|
||||||
|
var groupID *int64
|
||||||
|
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||||
|
id, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid group_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
groupID = &id
|
||||||
|
}
|
||||||
|
|
||||||
|
endTime := time.Now().UTC()
|
||||||
|
startTime := endTime.Add(-windowDur)
|
||||||
|
|
||||||
|
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
|
||||||
|
disabledSummary := &service.OpsRealtimeTrafficSummary{
|
||||||
|
Window: windowLabel,
|
||||||
|
StartTime: startTime,
|
||||||
|
EndTime: endTime,
|
||||||
|
Platform: platform,
|
||||||
|
GroupID: groupID,
|
||||||
|
QPS: service.OpsRateSummary{},
|
||||||
|
TPS: service.OpsRateSummary{},
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"enabled": false,
|
||||||
|
"summary": disabledSummary,
|
||||||
|
"timestamp": endTime,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := &service.OpsDashboardFilter{
|
||||||
|
StartTime: startTime,
|
||||||
|
EndTime: endTime,
|
||||||
|
Platform: platform,
|
||||||
|
GroupID: groupID,
|
||||||
|
QueryMode: service.OpsQueryModeRaw,
|
||||||
|
}
|
||||||
|
|
||||||
|
summary, err := h.opsService.GetRealtimeTrafficSummary(c.Request.Context(), filter)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if summary != nil {
|
||||||
|
summary.Window = windowLabel
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"enabled": true,
|
||||||
|
"summary": summary,
|
||||||
|
"timestamp": endTime,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -146,3 +146,49 @@ func (h *OpsHandler) UpdateAdvancedSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
response.Success(c, updated)
|
response.Success(c, updated)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetMetricThresholds returns Ops metric thresholds (DB-backed).
|
||||||
|
// GET /api/v1/admin/ops/settings/metric-thresholds
|
||||||
|
func (h *OpsHandler) GetMetricThresholds(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := h.opsService.GetMetricThresholds(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusInternalServerError, "Failed to get metric thresholds")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMetricThresholds updates Ops metric thresholds (DB-backed).
|
||||||
|
// PUT /api/v1/admin/ops/settings/metric-thresholds
|
||||||
|
func (h *OpsHandler) UpdateMetricThresholds(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req service.OpsMetricThresholds
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := h.opsService.UpdateMetricThresholds(c.Request.Context(), &req)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, updated)
|
||||||
|
}
|
||||||
|
|||||||
@@ -196,6 +196,28 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
|
|||||||
response.Success(c, gin.H{"message": "Proxy deleted successfully"})
|
response.Success(c, gin.H{"message": "Proxy deleted successfully"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchDelete handles batch deleting proxies
|
||||||
|
// POST /api/v1/admin/proxies/batch-delete
|
||||||
|
func (h *ProxyHandler) BatchDelete(c *gin.Context) {
|
||||||
|
type BatchDeleteRequest struct {
|
||||||
|
IDs []int64 `json:"ids" binding:"required,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req BatchDeleteRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.adminService.BatchDeleteProxies(c.Request.Context(), req.IDs)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
// Test handles testing proxy connectivity
|
// Test handles testing proxy connectivity
|
||||||
// POST /api/v1/admin/proxies/:id/test
|
// POST /api/v1/admin/proxies/:id/test
|
||||||
func (h *ProxyHandler) Test(c *gin.Context) {
|
func (h *ProxyHandler) Test(c *gin.Context) {
|
||||||
@@ -243,19 +265,17 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
page, pageSize := response.ParsePagination(c)
|
accounts, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID)
|
||||||
|
|
||||||
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
out := make([]dto.Account, 0, len(accounts))
|
out := make([]dto.ProxyAccountSummary, 0, len(accounts))
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
out = append(out, *dto.AccountFromService(&accounts[i]))
|
out = append(out, *dto.ProxyAccountSummaryFromService(&accounts[i]))
|
||||||
}
|
}
|
||||||
response.Paginated(c, out, total, page, pageSize)
|
response.Success(c, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||||
|
|||||||
@@ -654,3 +654,68 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
|
|||||||
|
|
||||||
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStreamTimeoutSettings 获取流超时处理配置
|
||||||
|
// GET /api/v1/admin/settings/stream-timeout
|
||||||
|
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
||||||
|
settings, err := h.settingService.GetStreamTimeoutSettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.StreamTimeoutSettings{
|
||||||
|
Enabled: settings.Enabled,
|
||||||
|
Action: settings.Action,
|
||||||
|
TempUnschedMinutes: settings.TempUnschedMinutes,
|
||||||
|
ThresholdCount: settings.ThresholdCount,
|
||||||
|
ThresholdWindowMinutes: settings.ThresholdWindowMinutes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||||
|
type UpdateStreamTimeoutSettingsRequest struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
TempUnschedMinutes int `json:"temp_unsched_minutes"`
|
||||||
|
ThresholdCount int `json:"threshold_count"`
|
||||||
|
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStreamTimeoutSettings 更新流超时处理配置
|
||||||
|
// PUT /api/v1/admin/settings/stream-timeout
|
||||||
|
func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
|
||||||
|
var req UpdateStreamTimeoutSettingsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := &service.StreamTimeoutSettings{
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
Action: req.Action,
|
||||||
|
TempUnschedMinutes: req.TempUnschedMinutes,
|
||||||
|
ThresholdCount: req.ThresholdCount,
|
||||||
|
ThresholdWindowMinutes: req.ThresholdWindowMinutes,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.settingService.SetStreamTimeoutSettings(c.Request.Context(), settings); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重新获取设置返回
|
||||||
|
updatedSettings, err := h.settingService.GetStreamTimeoutSettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.StreamTimeoutSettings{
|
||||||
|
Enabled: updatedSettings.Enabled,
|
||||||
|
Action: updatedSettings.Action,
|
||||||
|
TempUnschedMinutes: updatedSettings.TempUnschedMinutes,
|
||||||
|
ThresholdCount: updatedSettings.ThresholdCount,
|
||||||
|
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -76,7 +77,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
|||||||
|
|
||||||
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
|
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
|
||||||
if req.VerifyCode == "" {
|
if req.VerifyCode == "" {
|
||||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -105,7 +106,7 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Turnstile 验证
|
// Turnstile 验证
|
||||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -132,7 +133,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Turnstile 验证
|
// Turnstile 验证
|
||||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -125,6 +125,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
|||||||
ProxyID: a.ProxyID,
|
ProxyID: a.ProxyID,
|
||||||
Concurrency: a.Concurrency,
|
Concurrency: a.Concurrency,
|
||||||
Priority: a.Priority,
|
Priority: a.Priority,
|
||||||
|
RateMultiplier: a.BillingRateMultiplier(),
|
||||||
Status: a.Status,
|
Status: a.Status,
|
||||||
ErrorMessage: a.ErrorMessage,
|
ErrorMessage: a.ErrorMessage,
|
||||||
LastUsedAt: a.LastUsedAt,
|
LastUsedAt: a.LastUsedAt,
|
||||||
@@ -212,8 +213,29 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &ProxyWithAccountCount{
|
return &ProxyWithAccountCount{
|
||||||
Proxy: *ProxyFromService(&p.Proxy),
|
Proxy: *ProxyFromService(&p.Proxy),
|
||||||
AccountCount: p.AccountCount,
|
AccountCount: p.AccountCount,
|
||||||
|
LatencyMs: p.LatencyMs,
|
||||||
|
LatencyStatus: p.LatencyStatus,
|
||||||
|
LatencyMessage: p.LatencyMessage,
|
||||||
|
IPAddress: p.IPAddress,
|
||||||
|
Country: p.Country,
|
||||||
|
CountryCode: p.CountryCode,
|
||||||
|
Region: p.Region,
|
||||||
|
City: p.City,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
|
||||||
|
if a == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &ProxyAccountSummary{
|
||||||
|
ID: a.ID,
|
||||||
|
Name: a.Name,
|
||||||
|
Platform: a.Platform,
|
||||||
|
Type: a.Type,
|
||||||
|
Notes: a.Notes,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,6 +301,7 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
|
|||||||
TotalCost: l.TotalCost,
|
TotalCost: l.TotalCost,
|
||||||
ActualCost: l.ActualCost,
|
ActualCost: l.ActualCost,
|
||||||
RateMultiplier: l.RateMultiplier,
|
RateMultiplier: l.RateMultiplier,
|
||||||
|
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||||
BillingType: l.BillingType,
|
BillingType: l.BillingType,
|
||||||
Stream: l.Stream,
|
Stream: l.Stream,
|
||||||
DurationMs: l.DurationMs,
|
DurationMs: l.DurationMs,
|
||||||
|
|||||||
@@ -66,3 +66,12 @@ type PublicSettings struct {
|
|||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||||
|
type StreamTimeoutSettings struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
TempUnschedMinutes int `json:"temp_unsched_minutes"`
|
||||||
|
ThresholdCount int `json:"threshold_count"`
|
||||||
|
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ type Account struct {
|
|||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
Priority int `json:"priority"`
|
Priority int `json:"priority"`
|
||||||
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
ErrorMessage string `json:"error_message"`
|
ErrorMessage string `json:"error_message"`
|
||||||
LastUsedAt *time.Time `json:"last_used_at"`
|
LastUsedAt *time.Time `json:"last_used_at"`
|
||||||
@@ -129,7 +130,23 @@ type Proxy struct {
|
|||||||
|
|
||||||
type ProxyWithAccountCount struct {
|
type ProxyWithAccountCount struct {
|
||||||
Proxy
|
Proxy
|
||||||
AccountCount int64 `json:"account_count"`
|
AccountCount int64 `json:"account_count"`
|
||||||
|
LatencyMs *int64 `json:"latency_ms,omitempty"`
|
||||||
|
LatencyStatus string `json:"latency_status,omitempty"`
|
||||||
|
LatencyMessage string `json:"latency_message,omitempty"`
|
||||||
|
IPAddress string `json:"ip_address,omitempty"`
|
||||||
|
Country string `json:"country,omitempty"`
|
||||||
|
CountryCode string `json:"country_code,omitempty"`
|
||||||
|
Region string `json:"region,omitempty"`
|
||||||
|
City string `json:"city,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProxyAccountSummary struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Platform string `json:"platform"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Notes *string `json:"notes,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RedeemCode struct {
|
type RedeemCode struct {
|
||||||
@@ -169,13 +186,14 @@ type UsageLog struct {
|
|||||||
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
|
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
|
||||||
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
|
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
|
||||||
|
|
||||||
InputCost float64 `json:"input_cost"`
|
InputCost float64 `json:"input_cost"`
|
||||||
OutputCost float64 `json:"output_cost"`
|
OutputCost float64 `json:"output_cost"`
|
||||||
CacheCreationCost float64 `json:"cache_creation_cost"`
|
CacheCreationCost float64 `json:"cache_creation_cost"`
|
||||||
CacheReadCost float64 `json:"cache_read_cost"`
|
CacheReadCost float64 `json:"cache_read_cost"`
|
||||||
TotalCost float64 `json:"total_cost"`
|
TotalCost float64 `json:"total_cost"`
|
||||||
ActualCost float64 `json:"actual_cost"`
|
ActualCost float64 `json:"actual_cost"`
|
||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
|
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||||
|
|
||||||
BillingType int8 `json:"billing_type"`
|
BillingType int8 `json:"billing_type"`
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -88,6 +89,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||||
|
SetClaudeCodeClientContext(c, body)
|
||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
parsedReq, err := service.ParseGatewayRequest(body)
|
parsedReq, err := service.ParseGatewayRequest(body)
|
||||||
@@ -271,12 +275,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
switchCount++
|
switchCount++
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
continue
|
continue
|
||||||
@@ -286,8 +289,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
// 异步记录使用量(subscription已在函数开头获取)
|
// 异步记录使用量(subscription已在函数开头获取)
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
@@ -296,10 +303,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
UserAgent: ua,
|
||||||
|
IPAddress: clientIP,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
}(result, account)
|
}(result, account, userAgent, clientIP)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -399,12 +408,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
switchCount++
|
switchCount++
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
continue
|
continue
|
||||||
@@ -414,8 +422,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
// 异步记录使用量(subscription已在函数开头获取)
|
// 异步记录使用量(subscription已在函数开头获取)
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
@@ -424,10 +436,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
UserAgent: ua,
|
||||||
|
IPAddress: clientIP,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
}(result, account)
|
}(result, account, userAgent, clientIP)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
@@ -314,8 +315,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
// 6) record usage async
|
// 6) record usage async
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
@@ -324,10 +329,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
UserAgent: ua,
|
||||||
|
IPAddress: ip,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
}(result, account)
|
}(result, account, userAgent, clientIP)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -113,6 +114,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
|
|
||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
|
|
||||||
|
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||||
|
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
|
||||||
|
// 或带 id 且与 call_id 匹配的 item_reference。
|
||||||
|
if service.HasFunctionCallOutput(reqBody) {
|
||||||
|
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||||
|
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||||
|
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
||||||
|
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
||||||
|
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
||||||
|
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Track if we've started streaming (for error handling)
|
// Track if we've started streaming (for error handling)
|
||||||
streamStarted := false
|
streamStarted := false
|
||||||
|
|
||||||
@@ -263,8 +284,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
// Async record usage
|
// Async record usage
|
||||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
|
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
@@ -273,10 +298,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
UserAgent: ua,
|
||||||
|
IPAddress: ip,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
}(result, account)
|
}(result, account, userAgent, clientIP)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -489,6 +490,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
|||||||
Severity: classifyOpsSeverity("upstream_error", effectiveUpstreamStatus),
|
Severity: classifyOpsSeverity("upstream_error", effectiveUpstreamStatus),
|
||||||
StatusCode: status,
|
StatusCode: status,
|
||||||
IsBusinessLimited: false,
|
IsBusinessLimited: false,
|
||||||
|
IsCountTokens: isCountTokensRequest(c),
|
||||||
|
|
||||||
ErrorMessage: recoveredMsg,
|
ErrorMessage: recoveredMsg,
|
||||||
ErrorBody: "",
|
ErrorBody: "",
|
||||||
@@ -521,7 +523,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var clientIP string
|
var clientIP string
|
||||||
if ip := strings.TrimSpace(c.ClientIP()); ip != "" {
|
if ip := strings.TrimSpace(ip.GetClientIP(c)); ip != "" {
|
||||||
clientIP = ip
|
clientIP = ip
|
||||||
entry.ClientIP = &clientIP
|
entry.ClientIP = &clientIP
|
||||||
}
|
}
|
||||||
@@ -542,6 +544,11 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
|||||||
body := w.buf.Bytes()
|
body := w.buf.Bytes()
|
||||||
parsed := parseOpsErrorResponse(body)
|
parsed := parseOpsErrorResponse(body)
|
||||||
|
|
||||||
|
// Skip logging if the error should be filtered based on settings
|
||||||
|
if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
|
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
|
||||||
|
|
||||||
clientRequestID, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
clientRequestID, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
||||||
@@ -598,6 +605,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
|||||||
Severity: classifyOpsSeverity(parsed.ErrorType, status),
|
Severity: classifyOpsSeverity(parsed.ErrorType, status),
|
||||||
StatusCode: status,
|
StatusCode: status,
|
||||||
IsBusinessLimited: isBusinessLimited,
|
IsBusinessLimited: isBusinessLimited,
|
||||||
|
IsCountTokens: isCountTokensRequest(c),
|
||||||
|
|
||||||
ErrorMessage: parsed.Message,
|
ErrorMessage: parsed.Message,
|
||||||
// Keep the full captured error body (capture is already capped at 64KB) so the
|
// Keep the full captured error body (capture is already capped at 64KB) so the
|
||||||
@@ -680,7 +688,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var clientIP string
|
var clientIP string
|
||||||
if ip := strings.TrimSpace(c.ClientIP()); ip != "" {
|
if ip := strings.TrimSpace(ip.GetClientIP(c)); ip != "" {
|
||||||
clientIP = ip
|
clientIP = ip
|
||||||
entry.ClientIP = &clientIP
|
entry.ClientIP = &clientIP
|
||||||
}
|
}
|
||||||
@@ -704,6 +712,14 @@ var opsRetryRequestHeaderAllowlist = []string{
|
|||||||
"anthropic-version",
|
"anthropic-version",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isCountTokensRequest checks if the request is a count_tokens request
|
||||||
|
func isCountTokensRequest(c *gin.Context) bool {
|
||||||
|
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(c.Request.URL.Path, "/count_tokens")
|
||||||
|
}
|
||||||
|
|
||||||
func extractOpsRetryRequestHeaders(c *gin.Context) *string {
|
func extractOpsRetryRequestHeaders(c *gin.Context) *string {
|
||||||
if c == nil || c.Request == nil {
|
if c == nil || c.Request == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -821,28 +837,30 @@ func normalizeOpsErrorType(errType string, code string) string {
|
|||||||
|
|
||||||
func classifyOpsPhase(errType, message, code string) string {
|
func classifyOpsPhase(errType, message, code string) string {
|
||||||
msg := strings.ToLower(message)
|
msg := strings.ToLower(message)
|
||||||
|
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||||
|
// Map billing/concurrency/response => request; scheduling => routing.
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||||
return "billing"
|
return "request"
|
||||||
}
|
}
|
||||||
|
|
||||||
switch errType {
|
switch errType {
|
||||||
case "authentication_error":
|
case "authentication_error":
|
||||||
return "auth"
|
return "auth"
|
||||||
case "billing_error", "subscription_error":
|
case "billing_error", "subscription_error":
|
||||||
return "billing"
|
return "request"
|
||||||
case "rate_limit_error":
|
case "rate_limit_error":
|
||||||
if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") || strings.Contains(msg, "queue") {
|
if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") || strings.Contains(msg, "queue") {
|
||||||
return "concurrency"
|
return "request"
|
||||||
}
|
}
|
||||||
return "upstream"
|
return "upstream"
|
||||||
case "invalid_request_error":
|
case "invalid_request_error":
|
||||||
return "response"
|
return "request"
|
||||||
case "upstream_error", "overloaded_error":
|
case "upstream_error", "overloaded_error":
|
||||||
return "upstream"
|
return "upstream"
|
||||||
case "api_error":
|
case "api_error":
|
||||||
if strings.Contains(msg, "no available accounts") {
|
if strings.Contains(msg, "no available accounts") {
|
||||||
return "scheduling"
|
return "routing"
|
||||||
}
|
}
|
||||||
return "internal"
|
return "internal"
|
||||||
default:
|
default:
|
||||||
@@ -903,34 +921,38 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func classifyOpsErrorOwner(phase string, message string) string {
|
func classifyOpsErrorOwner(phase string, message string) string {
|
||||||
|
// Standardized owners: client|provider|platform
|
||||||
switch phase {
|
switch phase {
|
||||||
case "upstream", "network":
|
case "upstream", "network":
|
||||||
return "provider"
|
return "provider"
|
||||||
case "billing", "concurrency", "auth", "response":
|
case "request", "auth":
|
||||||
return "client"
|
return "client"
|
||||||
|
case "routing", "internal":
|
||||||
|
return "platform"
|
||||||
default:
|
default:
|
||||||
if strings.Contains(strings.ToLower(message), "upstream") {
|
if strings.Contains(strings.ToLower(message), "upstream") {
|
||||||
return "provider"
|
return "provider"
|
||||||
}
|
}
|
||||||
return "sub2api"
|
return "platform"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func classifyOpsErrorSource(phase string, message string) string {
|
func classifyOpsErrorSource(phase string, message string) string {
|
||||||
|
// Standardized sources: client_request|upstream_http|gateway
|
||||||
switch phase {
|
switch phase {
|
||||||
case "upstream":
|
case "upstream":
|
||||||
return "upstream_http"
|
return "upstream_http"
|
||||||
case "network":
|
case "network":
|
||||||
return "upstream_network"
|
return "gateway"
|
||||||
case "billing":
|
case "request", "auth":
|
||||||
return "billing"
|
return "client_request"
|
||||||
case "concurrency":
|
case "routing", "internal":
|
||||||
return "concurrency"
|
return "gateway"
|
||||||
default:
|
default:
|
||||||
if strings.Contains(strings.ToLower(message), "upstream") {
|
if strings.Contains(strings.ToLower(message), "upstream") {
|
||||||
return "upstream_http"
|
return "upstream_http"
|
||||||
}
|
}
|
||||||
return "internal"
|
return "gateway"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -952,3 +974,42 @@ func truncateString(s string, max int) string {
|
|||||||
func strconvItoa(v int) string {
|
func strconvItoa(v int) string {
|
||||||
return strconv.Itoa(v)
|
return strconv.Itoa(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shouldSkipOpsErrorLog determines if an error should be skipped from logging based on settings.
|
||||||
|
// Returns true for errors that should be filtered according to OpsAdvancedSettings.
|
||||||
|
func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message, body, requestPath string) bool {
|
||||||
|
if ops == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get advanced settings to check filter configuration
|
||||||
|
settings, err := ops.GetOpsAdvancedSettings(ctx)
|
||||||
|
if err != nil || settings == nil {
|
||||||
|
// If we can't get settings, don't skip (fail open)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
msgLower := strings.ToLower(message)
|
||||||
|
bodyLower := strings.ToLower(body)
|
||||||
|
|
||||||
|
// Check if count_tokens errors should be ignored
|
||||||
|
if settings.IgnoreCountTokensErrors && strings.Contains(requestPath, "/count_tokens") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if context canceled errors should be ignored (client disconnects)
|
||||||
|
if settings.IgnoreContextCanceled {
|
||||||
|
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if "no available accounts" errors should be ignored
|
||||||
|
if settings.IgnoreNoAvailableAccounts {
|
||||||
|
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -25,15 +28,34 @@ type RateLimitOptions struct {
|
|||||||
var rateLimitScript = redis.NewScript(`
|
var rateLimitScript = redis.NewScript(`
|
||||||
local current = redis.call('INCR', KEYS[1])
|
local current = redis.call('INCR', KEYS[1])
|
||||||
local ttl = redis.call('PTTL', KEYS[1])
|
local ttl = redis.call('PTTL', KEYS[1])
|
||||||
if current == 1 or ttl == -1 then
|
local repaired = 0
|
||||||
|
if current == 1 then
|
||||||
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
||||||
|
elseif ttl == -1 then
|
||||||
|
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
||||||
|
repaired = 1
|
||||||
end
|
end
|
||||||
return current
|
return {current, repaired}
|
||||||
`)
|
`)
|
||||||
|
|
||||||
// rateLimitRun 允许测试覆写脚本执行逻辑
|
// rateLimitRun 允许测试覆写脚本执行逻辑
|
||||||
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
||||||
return rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Int64()
|
values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice()
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
if len(values) < 2 {
|
||||||
|
return 0, false, fmt.Errorf("rate limit script returned %d values", len(values))
|
||||||
|
}
|
||||||
|
count, err := parseInt64(values[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
repaired, err := parseInt64(values[1])
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
return count, repaired == 1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RateLimiter Redis 速率限制器
|
// RateLimiter Redis 速率限制器
|
||||||
@@ -74,8 +96,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
|
|||||||
windowMillis := windowTTLMillis(window)
|
windowMillis := windowTTLMillis(window)
|
||||||
|
|
||||||
// 使用 Lua 脚本原子操作增加计数并设置过期
|
// 使用 Lua 脚本原子操作增加计数并设置过期
|
||||||
count, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
|
count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err)
|
||||||
if failureMode == RateLimitFailClose {
|
if failureMode == RateLimitFailClose {
|
||||||
abortRateLimit(c)
|
abortRateLimit(c)
|
||||||
return
|
return
|
||||||
@@ -84,6 +107,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
|
|||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if repaired {
|
||||||
|
log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis)
|
||||||
|
}
|
||||||
|
|
||||||
// 超过限制
|
// 超过限制
|
||||||
if count > int64(limit) {
|
if count > int64(limit) {
|
||||||
@@ -109,3 +135,27 @@ func abortRateLimit(c *gin.Context) {
|
|||||||
"message": "Too many requests, please try again later",
|
"message": "Too many requests, please try again later",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func failureModeLabel(mode RateLimitFailureMode) string {
|
||||||
|
if mode == RateLimitFailClose {
|
||||||
|
return "fail-close"
|
||||||
|
}
|
||||||
|
return "fail-open"
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseInt64(value any) (int64, error) {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int64:
|
||||||
|
return v, nil
|
||||||
|
case int:
|
||||||
|
return int64(v), nil
|
||||||
|
case string:
|
||||||
|
parsed, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return parsed, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unexpected value type %T", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -66,13 +66,13 @@ func TestRateLimiterSuccessAndLimit(t *testing.T) {
|
|||||||
originalRun := rateLimitRun
|
originalRun := rateLimitRun
|
||||||
counts := []int64{1, 2}
|
counts := []int64{1, 2}
|
||||||
callIndex := 0
|
callIndex := 0
|
||||||
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
||||||
if callIndex >= len(counts) {
|
if callIndex >= len(counts) {
|
||||||
return counts[len(counts)-1], nil
|
return counts[len(counts)-1], false, nil
|
||||||
}
|
}
|
||||||
value := counts[callIndex]
|
value := counts[callIndex]
|
||||||
callIndex++
|
callIndex++
|
||||||
return value, nil
|
return value, false, nil
|
||||||
}
|
}
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
rateLimitRun = originalRun
|
rateLimitRun = originalRun
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
package usagestats
|
package usagestats
|
||||||
|
|
||||||
// AccountStats 账号使用统计
|
// AccountStats 账号使用统计
|
||||||
|
//
|
||||||
|
// cost: 账号口径费用(使用 total_cost * account_rate_multiplier)
|
||||||
|
// standard_cost: 标准费用(使用 total_cost,不含倍率)
|
||||||
|
// user_cost: 用户/API Key 口径费用(使用 actual_cost,受分组倍率影响)
|
||||||
type AccountStats struct {
|
type AccountStats struct {
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Tokens int64 `json:"tokens"`
|
Tokens int64 `json:"tokens"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
StandardCost float64 `json:"standard_cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -147,14 +147,15 @@ type UsageLogFilters struct {
|
|||||||
|
|
||||||
// UsageStats represents usage statistics
|
// UsageStats represents usage statistics
|
||||||
type UsageStats struct {
|
type UsageStats struct {
|
||||||
TotalRequests int64 `json:"total_requests"`
|
TotalRequests int64 `json:"total_requests"`
|
||||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
TotalCost float64 `json:"total_cost"`
|
TotalCost float64 `json:"total_cost"`
|
||||||
TotalActualCost float64 `json:"total_actual_cost"`
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||||
|
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchUserUsageStats represents usage stats for a single user
|
// BatchUserUsageStats represents usage stats for a single user
|
||||||
@@ -177,25 +178,29 @@ type AccountUsageHistory struct {
|
|||||||
Label string `json:"label"`
|
Label string `json:"label"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Tokens int64 `json:"tokens"`
|
Tokens int64 `json:"tokens"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"` // 标准计费(total_cost)
|
||||||
ActualCost float64 `json:"actual_cost"`
|
ActualCost float64 `json:"actual_cost"` // 账号口径费用(total_cost * account_rate_multiplier)
|
||||||
|
UserCost float64 `json:"user_cost"` // 用户口径费用(actual_cost,受分组倍率影响)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountUsageSummary represents summary statistics for an account
|
// AccountUsageSummary represents summary statistics for an account
|
||||||
type AccountUsageSummary struct {
|
type AccountUsageSummary struct {
|
||||||
Days int `json:"days"`
|
Days int `json:"days"`
|
||||||
ActualDaysUsed int `json:"actual_days_used"`
|
ActualDaysUsed int `json:"actual_days_used"`
|
||||||
TotalCost float64 `json:"total_cost"`
|
TotalCost float64 `json:"total_cost"` // 账号口径费用
|
||||||
|
TotalUserCost float64 `json:"total_user_cost"` // 用户口径费用
|
||||||
TotalStandardCost float64 `json:"total_standard_cost"`
|
TotalStandardCost float64 `json:"total_standard_cost"`
|
||||||
TotalRequests int64 `json:"total_requests"`
|
TotalRequests int64 `json:"total_requests"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
AvgDailyCost float64 `json:"avg_daily_cost"`
|
AvgDailyCost float64 `json:"avg_daily_cost"` // 账号口径日均
|
||||||
|
AvgDailyUserCost float64 `json:"avg_daily_user_cost"`
|
||||||
AvgDailyRequests float64 `json:"avg_daily_requests"`
|
AvgDailyRequests float64 `json:"avg_daily_requests"`
|
||||||
AvgDailyTokens float64 `json:"avg_daily_tokens"`
|
AvgDailyTokens float64 `json:"avg_daily_tokens"`
|
||||||
AvgDurationMs float64 `json:"avg_duration_ms"`
|
AvgDurationMs float64 `json:"avg_duration_ms"`
|
||||||
Today *struct {
|
Today *struct {
|
||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Tokens int64 `json:"tokens"`
|
Tokens int64 `json:"tokens"`
|
||||||
} `json:"today"`
|
} `json:"today"`
|
||||||
@@ -203,6 +208,7 @@ type AccountUsageSummary struct {
|
|||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
Label string `json:"label"`
|
Label string `json:"label"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
} `json:"highest_cost_day"`
|
} `json:"highest_cost_day"`
|
||||||
HighestRequestDay *struct {
|
HighestRequestDay *struct {
|
||||||
@@ -210,6 +216,7 @@ type AccountUsageSummary struct {
|
|||||||
Label string `json:"label"`
|
Label string `json:"label"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
} `json:"highest_request_day"`
|
} `json:"highest_request_day"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -79,6 +80,10 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
|||||||
SetSchedulable(account.Schedulable).
|
SetSchedulable(account.Schedulable).
|
||||||
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
|
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
|
||||||
|
|
||||||
|
if account.RateMultiplier != nil {
|
||||||
|
builder.SetRateMultiplier(*account.RateMultiplier)
|
||||||
|
}
|
||||||
|
|
||||||
if account.ProxyID != nil {
|
if account.ProxyID != nil {
|
||||||
builder.SetProxyID(*account.ProxyID)
|
builder.SetProxyID(*account.ProxyID)
|
||||||
}
|
}
|
||||||
@@ -115,6 +120,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
|||||||
account.ID = created.ID
|
account.ID = created.ID
|
||||||
account.CreatedAt = created.CreatedAt
|
account.CreatedAt = created.CreatedAt
|
||||||
account.UpdatedAt = created.UpdatedAt
|
account.UpdatedAt = created.UpdatedAt
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -287,6 +295,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
|||||||
SetSchedulable(account.Schedulable).
|
SetSchedulable(account.Schedulable).
|
||||||
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
|
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
|
||||||
|
|
||||||
|
if account.RateMultiplier != nil {
|
||||||
|
builder.SetRateMultiplier(*account.RateMultiplier)
|
||||||
|
}
|
||||||
|
|
||||||
if account.ProxyID != nil {
|
if account.ProxyID != nil {
|
||||||
builder.SetProxyID(*account.ProxyID)
|
builder.SetProxyID(*account.ProxyID)
|
||||||
} else {
|
} else {
|
||||||
@@ -341,10 +353,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
|||||||
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||||||
}
|
}
|
||||||
account.UpdatedAt = updated.UpdatedAt
|
account.UpdatedAt = updated.UpdatedAt
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||||
|
groupIDs, err := r.loadAccountGroupIDs(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
// 使用事务保证账号与关联分组的删除原子性
|
// 使用事务保证账号与关联分组的删除原子性
|
||||||
tx, err := r.client.Tx(ctx)
|
tx, err := r.client.Tx(ctx)
|
||||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||||
@@ -368,7 +387,12 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tx != nil {
|
if tx != nil {
|
||||||
return tx.Commit()
|
if err := tx.Commit(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -455,7 +479,18 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
|
|||||||
Where(dbaccount.IDEQ(id)).
|
Where(dbaccount.IDEQ(id)).
|
||||||
SetLastUsedAt(now).
|
SetLastUsedAt(now).
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payload := map[string]any{
|
||||||
|
"last_used": map[string]int64{
|
||||||
|
strconv.FormatInt(id, 10): now.Unix(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||||
@@ -479,7 +514,18 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
|
|||||||
args = append(args, pq.Array(ids))
|
args = append(args, pq.Array(ids))
|
||||||
|
|
||||||
_, err := r.sql.ExecContext(ctx, caseSQL, args...)
|
_, err := r.sql.ExecContext(ctx, caseSQL, args...)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
lastUsedPayload := make(map[string]int64, len(updates))
|
||||||
|
for id, ts := range updates {
|
||||||
|
lastUsedPayload[strconv.FormatInt(id, 10)] = ts.Unix()
|
||||||
|
}
|
||||||
|
payload := map[string]any{"last_used": lastUsedPayload}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
@@ -488,7 +534,13 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
|||||||
SetStatus(service.StatusError).
|
SetStatus(service.StatusError).
|
||||||
SetErrorMessage(errorMsg).
|
SetErrorMessage(errorMsg).
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||||
@@ -497,7 +549,14 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
|
|||||||
SetGroupID(groupID).
|
SetGroupID(groupID).
|
||||||
SetPriority(priority).
|
SetPriority(priority).
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
|
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
|
||||||
@@ -507,7 +566,14 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
|
|||||||
dbaccountgroup.GroupIDEQ(groupID),
|
dbaccountgroup.GroupIDEQ(groupID),
|
||||||
).
|
).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
|
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
|
||||||
@@ -528,6 +594,10 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||||
|
existingGroupIDs, err := r.loadAccountGroupIDs(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
// 使用事务保证删除旧绑定与创建新绑定的原子性
|
// 使用事务保证删除旧绑定与创建新绑定的原子性
|
||||||
tx, err := r.client.Tx(ctx)
|
tx, err := r.client.Tx(ctx)
|
||||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||||
@@ -568,7 +638,13 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tx != nil {
|
if tx != nil {
|
||||||
return tx.Commit()
|
if err := tx.Commit(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -672,7 +748,13 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
|||||||
SetRateLimitedAt(now).
|
SetRateLimitedAt(now).
|
||||||
SetRateLimitResetAt(resetAt).
|
SetRateLimitResetAt(resetAt).
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||||
@@ -706,6 +788,9 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
|
|||||||
if affected == 0 {
|
if affected == 0 {
|
||||||
return service.ErrAccountNotFound
|
return service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -714,7 +799,13 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
|||||||
Where(dbaccount.IDEQ(id)).
|
Where(dbaccount.IDEQ(id)).
|
||||||
SetOverloadUntil(until).
|
SetOverloadUntil(until).
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
@@ -727,7 +818,13 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
|
|||||||
AND deleted_at IS NULL
|
AND deleted_at IS NULL
|
||||||
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
|
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
|
||||||
`, until, reason, id)
|
`, until, reason, id)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||||
@@ -739,7 +836,13 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
|
|||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
AND deleted_at IS NULL
|
AND deleted_at IS NULL
|
||||||
`, id)
|
`, id)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||||
@@ -749,7 +852,13 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
|||||||
ClearRateLimitResetAt().
|
ClearRateLimitResetAt().
|
||||||
ClearOverloadUntil().
|
ClearOverloadUntil().
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||||
@@ -770,6 +879,9 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
|
|||||||
if affected == 0 {
|
if affected == 0 {
|
||||||
return service.ErrAccountNotFound
|
return service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -792,7 +904,13 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
|||||||
Where(dbaccount.IDEQ(id)).
|
Where(dbaccount.IDEQ(id)).
|
||||||
SetSchedulable(schedulable).
|
SetSchedulable(schedulable).
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||||
@@ -813,6 +931,11 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
if rows > 0 {
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return rows, nil
|
return rows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -844,6 +967,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
|||||||
if affected == 0 {
|
if affected == 0 {
|
||||||
return service.ErrAccountNotFound
|
return service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -881,6 +1007,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
|||||||
args = append(args, *updates.Priority)
|
args = append(args, *updates.Priority)
|
||||||
idx++
|
idx++
|
||||||
}
|
}
|
||||||
|
if updates.RateMultiplier != nil {
|
||||||
|
setClauses = append(setClauses, "rate_multiplier = $"+itoa(idx))
|
||||||
|
args = append(args, *updates.RateMultiplier)
|
||||||
|
idx++
|
||||||
|
}
|
||||||
if updates.Status != nil {
|
if updates.Status != nil {
|
||||||
setClauses = append(setClauses, "status = $"+itoa(idx))
|
setClauses = append(setClauses, "status = $"+itoa(idx))
|
||||||
args = append(args, *updates.Status)
|
args = append(args, *updates.Status)
|
||||||
@@ -928,6 +1059,12 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
if rows > 0 {
|
||||||
|
payload := map[string]any{"account_ids": ids}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return rows, nil
|
return rows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1170,11 +1307,61 @@ func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []
|
|||||||
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
|
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *accountRepository) loadAccountGroupIDs(ctx context.Context, accountID int64) ([]int64, error) {
|
||||||
|
entries, err := r.client.AccountGroup.
|
||||||
|
Query().
|
||||||
|
Where(dbaccountgroup.AccountIDEQ(accountID)).
|
||||||
|
All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ids := make([]int64, 0, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
ids = append(ids, entry.GroupID)
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeGroupIDs(a []int64, b []int64) []int64 {
|
||||||
|
seen := make(map[int64]struct{}, len(a)+len(b))
|
||||||
|
out := make([]int64, 0, len(a)+len(b))
|
||||||
|
for _, id := range a {
|
||||||
|
if id <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
for _, id := range b {
|
||||||
|
if id <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildSchedulerGroupPayload(groupIDs []int64) map[string]any {
|
||||||
|
if len(groupIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return map[string]any{"group_ids": groupIDs}
|
||||||
|
}
|
||||||
|
|
||||||
func accountEntityToService(m *dbent.Account) *service.Account {
|
func accountEntityToService(m *dbent.Account) *service.Account {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rateMultiplier := m.RateMultiplier
|
||||||
|
|
||||||
return &service.Account{
|
return &service.Account{
|
||||||
ID: m.ID,
|
ID: m.ID,
|
||||||
Name: m.Name,
|
Name: m.Name,
|
||||||
@@ -1186,6 +1373,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
|
|||||||
ProxyID: m.ProxyID,
|
ProxyID: m.ProxyID,
|
||||||
Concurrency: m.Concurrency,
|
Concurrency: m.Concurrency,
|
||||||
Priority: m.Priority,
|
Priority: m.Priority,
|
||||||
|
RateMultiplier: &rateMultiplier,
|
||||||
Status: m.Status,
|
Status: m.Status,
|
||||||
ErrorMessage: derefString(m.ErrorMessage),
|
ErrorMessage: derefString(m.ErrorMessage),
|
||||||
LastUsedAt: m.LastUsedAt,
|
LastUsedAt: m.LastUsedAt,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
)
|
)
|
||||||
@@ -41,21 +42,22 @@ func isPostgresDriver(db *sql.DB) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
startUTC := start.UTC()
|
loc := timezone.Location()
|
||||||
endUTC := end.UTC()
|
startLocal := start.In(loc)
|
||||||
if !endUTC.After(startUTC) {
|
endLocal := end.In(loc)
|
||||||
|
if !endLocal.After(startLocal) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
hourStart := startUTC.Truncate(time.Hour)
|
hourStart := startLocal.Truncate(time.Hour)
|
||||||
hourEnd := endUTC.Truncate(time.Hour)
|
hourEnd := endLocal.Truncate(time.Hour)
|
||||||
if endUTC.After(hourEnd) {
|
if endLocal.After(hourEnd) {
|
||||||
hourEnd = hourEnd.Add(time.Hour)
|
hourEnd = hourEnd.Add(time.Hour)
|
||||||
}
|
}
|
||||||
|
|
||||||
dayStart := truncateToDayUTC(startUTC)
|
dayStart := truncateToDay(startLocal)
|
||||||
dayEnd := truncateToDayUTC(endUTC)
|
dayEnd := truncateToDay(endLocal)
|
||||||
if endUTC.After(dayEnd) {
|
if endLocal.After(dayEnd) {
|
||||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,38 +148,41 @@ func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
|
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
|
||||||
|
tzName := timezone.Name()
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
|
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
|
||||||
SELECT DISTINCT
|
SELECT DISTINCT
|
||||||
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
|
||||||
user_id
|
user_id
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE created_at >= $1 AND created_at < $2
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
ON CONFLICT DO NOTHING
|
ON CONFLICT DO NOTHING
|
||||||
`
|
`
|
||||||
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
|
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
|
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
|
||||||
|
tzName := timezone.Name()
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
|
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
|
||||||
SELECT DISTINCT
|
SELECT DISTINCT
|
||||||
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
|
(bucket_start AT TIME ZONE $3)::date AS bucket_date,
|
||||||
user_id
|
user_id
|
||||||
FROM usage_dashboard_hourly_users
|
FROM usage_dashboard_hourly_users
|
||||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||||
ON CONFLICT DO NOTHING
|
ON CONFLICT DO NOTHING
|
||||||
`
|
`
|
||||||
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
|
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
|
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
|
||||||
|
tzName := timezone.Name()
|
||||||
query := `
|
query := `
|
||||||
WITH hourly AS (
|
WITH hourly AS (
|
||||||
SELECT
|
SELECT
|
||||||
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
|
||||||
COUNT(*) AS total_requests,
|
COUNT(*) AS total_requests,
|
||||||
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
||||||
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
||||||
@@ -236,15 +241,16 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
|
|||||||
active_users = EXCLUDED.active_users,
|
active_users = EXCLUDED.active_users,
|
||||||
computed_at = EXCLUDED.computed_at
|
computed_at = EXCLUDED.computed_at
|
||||||
`
|
`
|
||||||
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
|
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
|
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
|
||||||
|
tzName := timezone.Name()
|
||||||
query := `
|
query := `
|
||||||
WITH daily AS (
|
WITH daily AS (
|
||||||
SELECT
|
SELECT
|
||||||
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
|
(bucket_start AT TIME ZONE $5)::date AS bucket_date,
|
||||||
COALESCE(SUM(total_requests), 0) AS total_requests,
|
COALESCE(SUM(total_requests), 0) AS total_requests,
|
||||||
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
||||||
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
||||||
@@ -255,7 +261,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
|
|||||||
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
|
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
|
||||||
FROM usage_dashboard_hourly
|
FROM usage_dashboard_hourly
|
||||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||||
GROUP BY (bucket_start AT TIME ZONE 'UTC')::date
|
GROUP BY (bucket_start AT TIME ZONE $5)::date
|
||||||
),
|
),
|
||||||
user_counts AS (
|
user_counts AS (
|
||||||
SELECT bucket_date, COUNT(*) AS active_users
|
SELECT bucket_date, COUNT(*) AS active_users
|
||||||
@@ -303,7 +309,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
|
|||||||
active_users = EXCLUDED.active_users,
|
active_users = EXCLUDED.active_users,
|
||||||
computed_at = EXCLUDED.computed_at
|
computed_at = EXCLUDED.computed_at
|
||||||
`
|
`
|
||||||
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC(), start.UTC(), end.UTC())
|
_, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -376,9 +382,8 @@ func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Co
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func truncateToDayUTC(t time.Time) time.Time {
|
func truncateToDay(t time.Time) time.Time {
|
||||||
t = t.UTC()
|
return timezone.StartOfDay(t)
|
||||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func truncateToMonthUTC(t time.Time) time.Time {
|
func truncateToMonthUTC(t time.Time) time.Time {
|
||||||
|
|||||||
@@ -33,6 +33,11 @@ func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string,
|
|||||||
return c.rdb.Set(ctx, key, token, ttl).Err()
|
return c.rdb.Set(ctx, key, token, ttl).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||||
|
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||||
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
|
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
|
||||||
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||||
|
|||||||
@@ -0,0 +1,47 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GeminiTokenCacheSuite struct {
|
||||||
|
IntegrationRedisSuite
|
||||||
|
cache service.GeminiTokenCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeminiTokenCacheSuite) SetupTest() {
|
||||||
|
s.IntegrationRedisSuite.SetupTest()
|
||||||
|
s.cache = NewGeminiTokenCache(s.rdb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
|
||||||
|
cacheKey := "project-123"
|
||||||
|
token := "token-value"
|
||||||
|
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
|
||||||
|
|
||||||
|
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), token, got)
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
|
||||||
|
|
||||||
|
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
|
||||||
|
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
|
||||||
|
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiTokenCacheSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(GeminiTokenCacheSuite))
|
||||||
|
}
|
||||||
28
backend/internal/repository/gemini_token_cache_test.go
Normal file
28
backend/internal/repository/gemini_token_cache_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: "127.0.0.1:1",
|
||||||
|
DialTimeout: 50 * time.Millisecond,
|
||||||
|
ReadTimeout: 50 * time.Millisecond,
|
||||||
|
WriteTimeout: 50 * time.Millisecond,
|
||||||
|
})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = rdb.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
cache := NewGeminiTokenCache(rdb)
|
||||||
|
err := cache.DeleteAccessToken(context.Background(), "broken")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
@@ -55,6 +56,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
|||||||
groupIn.ID = created.ID
|
groupIn.ID = created.ID
|
||||||
groupIn.CreatedAt = created.CreatedAt
|
groupIn.CreatedAt = created.CreatedAt
|
||||||
groupIn.UpdatedAt = created.UpdatedAt
|
groupIn.UpdatedAt = created.UpdatedAt
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return translatePersistenceError(err, nil, service.ErrGroupExists)
|
return translatePersistenceError(err, nil, service.ErrGroupExists)
|
||||||
}
|
}
|
||||||
@@ -111,12 +115,21 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
|||||||
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
||||||
}
|
}
|
||||||
groupIn.UpdatedAt = updated.UpdatedAt
|
groupIn.UpdatedAt = updated.UpdatedAt
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||||
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
|
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
|
||||||
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
if err != nil {
|
||||||
|
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||||
@@ -246,6 +259,9 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
affected, _ := res.RowsAffected()
|
affected, _ := res.RowsAffected()
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
|
||||||
|
}
|
||||||
return affected, nil
|
return affected, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -353,6 +369,9 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
return affectedUserIDs, nil
|
return affectedUserIDs, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,23 @@ CREATE TABLE IF NOT EXISTS schema_migrations (
|
|||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
|
const atlasSchemaRevisionsTableDDL = `
|
||||||
|
CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
|
||||||
|
version TEXT PRIMARY KEY,
|
||||||
|
description TEXT NOT NULL,
|
||||||
|
type INTEGER NOT NULL,
|
||||||
|
applied INTEGER NOT NULL DEFAULT 0,
|
||||||
|
total INTEGER NOT NULL DEFAULT 0,
|
||||||
|
executed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
execution_time BIGINT NOT NULL DEFAULT 0,
|
||||||
|
error TEXT NULL,
|
||||||
|
error_stmt TEXT NULL,
|
||||||
|
hash TEXT NOT NULL DEFAULT '',
|
||||||
|
partial_hashes TEXT[] NULL,
|
||||||
|
operator_version TEXT NULL
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
|
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
|
||||||
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
|
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
|
||||||
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
|
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
|
||||||
@@ -94,6 +111,11 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
|||||||
return fmt.Errorf("create schema_migrations: %w", err)
|
return fmt.Errorf("create schema_migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 自动对齐 Atlas 基线(如果检测到 legacy schema_migrations 且缺失 atlas_schema_revisions)。
|
||||||
|
if err := ensureAtlasBaselineAligned(ctx, db, fsys); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// 获取所有 .sql 迁移文件并按文件名排序。
|
// 获取所有 .sql 迁移文件并按文件名排序。
|
||||||
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql)。
|
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql)。
|
||||||
files, err := fs.Glob(fsys, "*.sql")
|
files, err := fs.Glob(fsys, "*.sql")
|
||||||
@@ -172,6 +194,80 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||||
|
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check schema_migrations: %w", err)
|
||||||
|
}
|
||||||
|
if !hasLegacy {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hasAtlas, err := tableExists(ctx, db, "atlas_schema_revisions")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check atlas_schema_revisions: %w", err)
|
||||||
|
}
|
||||||
|
if !hasAtlas {
|
||||||
|
if _, err := db.ExecContext(ctx, atlasSchemaRevisionsTableDDL); err != nil {
|
||||||
|
return fmt.Errorf("create atlas_schema_revisions: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int
|
||||||
|
if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM atlas_schema_revisions").Scan(&count); err != nil {
|
||||||
|
return fmt.Errorf("count atlas_schema_revisions: %w", err)
|
||||||
|
}
|
||||||
|
if count > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
version, description, hash, err := latestMigrationBaseline(fsys)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("atlas baseline version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.ExecContext(ctx, `
|
||||||
|
INSERT INTO atlas_schema_revisions (version, description, type, applied, total, executed_at, execution_time, hash)
|
||||||
|
VALUES ($1, $2, $3, 0, 0, NOW(), 0, $4)
|
||||||
|
`, version, description, 1, hash); err != nil {
|
||||||
|
return fmt.Errorf("insert atlas baseline: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func tableExists(ctx context.Context, db *sql.DB, tableName string) (bool, error) {
|
||||||
|
var exists bool
|
||||||
|
err := db.QueryRowContext(ctx, `
|
||||||
|
SELECT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM information_schema.tables
|
||||||
|
WHERE table_schema = 'public' AND table_name = $1
|
||||||
|
)
|
||||||
|
`, tableName).Scan(&exists)
|
||||||
|
return exists, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
|
||||||
|
files, err := fs.Glob(fsys, "*.sql")
|
||||||
|
if err != nil {
|
||||||
|
return "", "", "", err
|
||||||
|
}
|
||||||
|
if len(files) == 0 {
|
||||||
|
return "baseline", "baseline", "", nil
|
||||||
|
}
|
||||||
|
sort.Strings(files)
|
||||||
|
name := files[len(files)-1]
|
||||||
|
contentBytes, err := fs.ReadFile(fsys, name)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", "", err
|
||||||
|
}
|
||||||
|
content := strings.TrimSpace(string(contentBytes))
|
||||||
|
sum := sha256.Sum256([]byte(content))
|
||||||
|
hash := hex.EncodeToString(sum[:])
|
||||||
|
version := strings.TrimSuffix(name, ".sql")
|
||||||
|
return version, version, hash, nil
|
||||||
|
}
|
||||||
|
|
||||||
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
|
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
|
||||||
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
|
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
|
||||||
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
|
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ INSERT INTO ops_error_logs (
|
|||||||
severity,
|
severity,
|
||||||
status_code,
|
status_code,
|
||||||
is_business_limited,
|
is_business_limited,
|
||||||
|
is_count_tokens,
|
||||||
error_message,
|
error_message,
|
||||||
error_body,
|
error_body,
|
||||||
error_source,
|
error_source,
|
||||||
@@ -54,7 +55,6 @@ INSERT INTO ops_error_logs (
|
|||||||
upstream_error_message,
|
upstream_error_message,
|
||||||
upstream_error_detail,
|
upstream_error_detail,
|
||||||
upstream_errors,
|
upstream_errors,
|
||||||
duration_ms,
|
|
||||||
time_to_first_token_ms,
|
time_to_first_token_ms,
|
||||||
request_body,
|
request_body,
|
||||||
request_body_truncated,
|
request_body_truncated,
|
||||||
@@ -88,6 +88,7 @@ INSERT INTO ops_error_logs (
|
|||||||
opsNullString(input.Severity),
|
opsNullString(input.Severity),
|
||||||
opsNullInt(input.StatusCode),
|
opsNullInt(input.StatusCode),
|
||||||
input.IsBusinessLimited,
|
input.IsBusinessLimited,
|
||||||
|
input.IsCountTokens,
|
||||||
opsNullString(input.ErrorMessage),
|
opsNullString(input.ErrorMessage),
|
||||||
opsNullString(input.ErrorBody),
|
opsNullString(input.ErrorBody),
|
||||||
opsNullString(input.ErrorSource),
|
opsNullString(input.ErrorSource),
|
||||||
@@ -96,7 +97,6 @@ INSERT INTO ops_error_logs (
|
|||||||
opsNullString(input.UpstreamErrorMessage),
|
opsNullString(input.UpstreamErrorMessage),
|
||||||
opsNullString(input.UpstreamErrorDetail),
|
opsNullString(input.UpstreamErrorDetail),
|
||||||
opsNullString(input.UpstreamErrorsJSON),
|
opsNullString(input.UpstreamErrorsJSON),
|
||||||
opsNullInt(input.DurationMs),
|
|
||||||
opsNullInt64(input.TimeToFirstTokenMs),
|
opsNullInt64(input.TimeToFirstTokenMs),
|
||||||
opsNullString(input.RequestBodyJSON),
|
opsNullString(input.RequestBodyJSON),
|
||||||
input.RequestBodyTruncated,
|
input.RequestBodyTruncated,
|
||||||
@@ -133,7 +133,7 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
|
|||||||
}
|
}
|
||||||
|
|
||||||
where, args := buildOpsErrorLogsWhere(filter)
|
where, args := buildOpsErrorLogsWhere(filter)
|
||||||
countSQL := "SELECT COUNT(*) FROM ops_error_logs " + where
|
countSQL := "SELECT COUNT(*) FROM ops_error_logs e " + where
|
||||||
|
|
||||||
var total int
|
var total int
|
||||||
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
|
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
|
||||||
@@ -144,28 +144,43 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
|
|||||||
argsWithLimit := append(args, pageSize, offset)
|
argsWithLimit := append(args, pageSize, offset)
|
||||||
selectSQL := `
|
selectSQL := `
|
||||||
SELECT
|
SELECT
|
||||||
id,
|
e.id,
|
||||||
created_at,
|
e.created_at,
|
||||||
error_phase,
|
e.error_phase,
|
||||||
error_type,
|
e.error_type,
|
||||||
severity,
|
COALESCE(e.error_owner, ''),
|
||||||
COALESCE(upstream_status_code, status_code, 0),
|
COALESCE(e.error_source, ''),
|
||||||
COALESCE(platform, ''),
|
e.severity,
|
||||||
COALESCE(model, ''),
|
COALESCE(e.upstream_status_code, e.status_code, 0),
|
||||||
duration_ms,
|
COALESCE(e.platform, ''),
|
||||||
COALESCE(client_request_id, ''),
|
COALESCE(e.model, ''),
|
||||||
COALESCE(request_id, ''),
|
COALESCE(e.is_retryable, false),
|
||||||
COALESCE(error_message, ''),
|
COALESCE(e.retry_count, 0),
|
||||||
user_id,
|
COALESCE(e.resolved, false),
|
||||||
api_key_id,
|
e.resolved_at,
|
||||||
account_id,
|
e.resolved_by_user_id,
|
||||||
group_id,
|
COALESCE(u2.email, ''),
|
||||||
CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
|
e.resolved_retry_id,
|
||||||
COALESCE(request_path, ''),
|
COALESCE(e.client_request_id, ''),
|
||||||
stream
|
COALESCE(e.request_id, ''),
|
||||||
FROM ops_error_logs
|
COALESCE(e.error_message, ''),
|
||||||
|
e.user_id,
|
||||||
|
COALESCE(u.email, ''),
|
||||||
|
e.api_key_id,
|
||||||
|
e.account_id,
|
||||||
|
COALESCE(a.name, ''),
|
||||||
|
e.group_id,
|
||||||
|
COALESCE(g.name, ''),
|
||||||
|
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
|
||||||
|
COALESCE(e.request_path, ''),
|
||||||
|
e.stream
|
||||||
|
FROM ops_error_logs e
|
||||||
|
LEFT JOIN accounts a ON e.account_id = a.id
|
||||||
|
LEFT JOIN groups g ON e.group_id = g.id
|
||||||
|
LEFT JOIN users u ON e.user_id = u.id
|
||||||
|
LEFT JOIN users u2 ON e.resolved_by_user_id = u2.id
|
||||||
` + where + `
|
` + where + `
|
||||||
ORDER BY created_at DESC
|
ORDER BY e.created_at DESC
|
||||||
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
|
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
|
||||||
|
|
||||||
rows, err := r.db.QueryContext(ctx, selectSQL, argsWithLimit...)
|
rows, err := r.db.QueryContext(ctx, selectSQL, argsWithLimit...)
|
||||||
@@ -177,39 +192,65 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
|
|||||||
out := make([]*service.OpsErrorLog, 0, pageSize)
|
out := make([]*service.OpsErrorLog, 0, pageSize)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var item service.OpsErrorLog
|
var item service.OpsErrorLog
|
||||||
var latency sql.NullInt64
|
|
||||||
var statusCode sql.NullInt64
|
var statusCode sql.NullInt64
|
||||||
var clientIP sql.NullString
|
var clientIP sql.NullString
|
||||||
var userID sql.NullInt64
|
var userID sql.NullInt64
|
||||||
var apiKeyID sql.NullInt64
|
var apiKeyID sql.NullInt64
|
||||||
var accountID sql.NullInt64
|
var accountID sql.NullInt64
|
||||||
|
var accountName string
|
||||||
var groupID sql.NullInt64
|
var groupID sql.NullInt64
|
||||||
|
var groupName string
|
||||||
|
var userEmail string
|
||||||
|
var resolvedAt sql.NullTime
|
||||||
|
var resolvedBy sql.NullInt64
|
||||||
|
var resolvedByName string
|
||||||
|
var resolvedRetryID sql.NullInt64
|
||||||
if err := rows.Scan(
|
if err := rows.Scan(
|
||||||
&item.ID,
|
&item.ID,
|
||||||
&item.CreatedAt,
|
&item.CreatedAt,
|
||||||
&item.Phase,
|
&item.Phase,
|
||||||
&item.Type,
|
&item.Type,
|
||||||
|
&item.Owner,
|
||||||
|
&item.Source,
|
||||||
&item.Severity,
|
&item.Severity,
|
||||||
&statusCode,
|
&statusCode,
|
||||||
&item.Platform,
|
&item.Platform,
|
||||||
&item.Model,
|
&item.Model,
|
||||||
&latency,
|
&item.IsRetryable,
|
||||||
|
&item.RetryCount,
|
||||||
|
&item.Resolved,
|
||||||
|
&resolvedAt,
|
||||||
|
&resolvedBy,
|
||||||
|
&resolvedByName,
|
||||||
|
&resolvedRetryID,
|
||||||
&item.ClientRequestID,
|
&item.ClientRequestID,
|
||||||
&item.RequestID,
|
&item.RequestID,
|
||||||
&item.Message,
|
&item.Message,
|
||||||
&userID,
|
&userID,
|
||||||
|
&userEmail,
|
||||||
&apiKeyID,
|
&apiKeyID,
|
||||||
&accountID,
|
&accountID,
|
||||||
|
&accountName,
|
||||||
&groupID,
|
&groupID,
|
||||||
|
&groupName,
|
||||||
&clientIP,
|
&clientIP,
|
||||||
&item.RequestPath,
|
&item.RequestPath,
|
||||||
&item.Stream,
|
&item.Stream,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if latency.Valid {
|
if resolvedAt.Valid {
|
||||||
v := int(latency.Int64)
|
t := resolvedAt.Time
|
||||||
item.LatencyMs = &v
|
item.ResolvedAt = &t
|
||||||
|
}
|
||||||
|
if resolvedBy.Valid {
|
||||||
|
v := resolvedBy.Int64
|
||||||
|
item.ResolvedByUserID = &v
|
||||||
|
}
|
||||||
|
item.ResolvedByUserName = resolvedByName
|
||||||
|
if resolvedRetryID.Valid {
|
||||||
|
v := resolvedRetryID.Int64
|
||||||
|
item.ResolvedRetryID = &v
|
||||||
}
|
}
|
||||||
item.StatusCode = int(statusCode.Int64)
|
item.StatusCode = int(statusCode.Int64)
|
||||||
if clientIP.Valid {
|
if clientIP.Valid {
|
||||||
@@ -220,6 +261,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
|
|||||||
v := userID.Int64
|
v := userID.Int64
|
||||||
item.UserID = &v
|
item.UserID = &v
|
||||||
}
|
}
|
||||||
|
item.UserEmail = userEmail
|
||||||
if apiKeyID.Valid {
|
if apiKeyID.Valid {
|
||||||
v := apiKeyID.Int64
|
v := apiKeyID.Int64
|
||||||
item.APIKeyID = &v
|
item.APIKeyID = &v
|
||||||
@@ -228,10 +270,12 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
|
|||||||
v := accountID.Int64
|
v := accountID.Int64
|
||||||
item.AccountID = &v
|
item.AccountID = &v
|
||||||
}
|
}
|
||||||
|
item.AccountName = accountName
|
||||||
if groupID.Valid {
|
if groupID.Valid {
|
||||||
v := groupID.Int64
|
v := groupID.Int64
|
||||||
item.GroupID = &v
|
item.GroupID = &v
|
||||||
}
|
}
|
||||||
|
item.GroupName = groupName
|
||||||
out = append(out, &item)
|
out = append(out, &item)
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
@@ -256,49 +300,64 @@ func (r *opsRepository) GetErrorLogByID(ctx context.Context, id int64) (*service
|
|||||||
|
|
||||||
q := `
|
q := `
|
||||||
SELECT
|
SELECT
|
||||||
id,
|
e.id,
|
||||||
created_at,
|
e.created_at,
|
||||||
error_phase,
|
e.error_phase,
|
||||||
error_type,
|
e.error_type,
|
||||||
severity,
|
COALESCE(e.error_owner, ''),
|
||||||
COALESCE(upstream_status_code, status_code, 0),
|
COALESCE(e.error_source, ''),
|
||||||
COALESCE(platform, ''),
|
e.severity,
|
||||||
COALESCE(model, ''),
|
COALESCE(e.upstream_status_code, e.status_code, 0),
|
||||||
duration_ms,
|
COALESCE(e.platform, ''),
|
||||||
COALESCE(client_request_id, ''),
|
COALESCE(e.model, ''),
|
||||||
COALESCE(request_id, ''),
|
COALESCE(e.is_retryable, false),
|
||||||
COALESCE(error_message, ''),
|
COALESCE(e.retry_count, 0),
|
||||||
COALESCE(error_body, ''),
|
COALESCE(e.resolved, false),
|
||||||
upstream_status_code,
|
e.resolved_at,
|
||||||
COALESCE(upstream_error_message, ''),
|
e.resolved_by_user_id,
|
||||||
COALESCE(upstream_error_detail, ''),
|
e.resolved_retry_id,
|
||||||
COALESCE(upstream_errors::text, ''),
|
COALESCE(e.client_request_id, ''),
|
||||||
is_business_limited,
|
COALESCE(e.request_id, ''),
|
||||||
user_id,
|
COALESCE(e.error_message, ''),
|
||||||
api_key_id,
|
COALESCE(e.error_body, ''),
|
||||||
account_id,
|
e.upstream_status_code,
|
||||||
group_id,
|
COALESCE(e.upstream_error_message, ''),
|
||||||
CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
|
COALESCE(e.upstream_error_detail, ''),
|
||||||
COALESCE(request_path, ''),
|
COALESCE(e.upstream_errors::text, ''),
|
||||||
stream,
|
e.is_business_limited,
|
||||||
COALESCE(user_agent, ''),
|
e.user_id,
|
||||||
auth_latency_ms,
|
COALESCE(u.email, ''),
|
||||||
routing_latency_ms,
|
e.api_key_id,
|
||||||
upstream_latency_ms,
|
e.account_id,
|
||||||
response_latency_ms,
|
COALESCE(a.name, ''),
|
||||||
time_to_first_token_ms,
|
e.group_id,
|
||||||
COALESCE(request_body::text, ''),
|
COALESCE(g.name, ''),
|
||||||
request_body_truncated,
|
CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
|
||||||
request_body_bytes,
|
COALESCE(e.request_path, ''),
|
||||||
COALESCE(request_headers::text, '')
|
e.stream,
|
||||||
FROM ops_error_logs
|
COALESCE(e.user_agent, ''),
|
||||||
WHERE id = $1
|
e.auth_latency_ms,
|
||||||
|
e.routing_latency_ms,
|
||||||
|
e.upstream_latency_ms,
|
||||||
|
e.response_latency_ms,
|
||||||
|
e.time_to_first_token_ms,
|
||||||
|
COALESCE(e.request_body::text, ''),
|
||||||
|
e.request_body_truncated,
|
||||||
|
e.request_body_bytes,
|
||||||
|
COALESCE(e.request_headers::text, '')
|
||||||
|
FROM ops_error_logs e
|
||||||
|
LEFT JOIN users u ON e.user_id = u.id
|
||||||
|
LEFT JOIN accounts a ON e.account_id = a.id
|
||||||
|
LEFT JOIN groups g ON e.group_id = g.id
|
||||||
|
WHERE e.id = $1
|
||||||
LIMIT 1`
|
LIMIT 1`
|
||||||
|
|
||||||
var out service.OpsErrorLogDetail
|
var out service.OpsErrorLogDetail
|
||||||
var latency sql.NullInt64
|
|
||||||
var statusCode sql.NullInt64
|
var statusCode sql.NullInt64
|
||||||
var upstreamStatusCode sql.NullInt64
|
var upstreamStatusCode sql.NullInt64
|
||||||
|
var resolvedAt sql.NullTime
|
||||||
|
var resolvedBy sql.NullInt64
|
||||||
|
var resolvedRetryID sql.NullInt64
|
||||||
var clientIP sql.NullString
|
var clientIP sql.NullString
|
||||||
var userID sql.NullInt64
|
var userID sql.NullInt64
|
||||||
var apiKeyID sql.NullInt64
|
var apiKeyID sql.NullInt64
|
||||||
@@ -316,11 +375,18 @@ LIMIT 1`
|
|||||||
&out.CreatedAt,
|
&out.CreatedAt,
|
||||||
&out.Phase,
|
&out.Phase,
|
||||||
&out.Type,
|
&out.Type,
|
||||||
|
&out.Owner,
|
||||||
|
&out.Source,
|
||||||
&out.Severity,
|
&out.Severity,
|
||||||
&statusCode,
|
&statusCode,
|
||||||
&out.Platform,
|
&out.Platform,
|
||||||
&out.Model,
|
&out.Model,
|
||||||
&latency,
|
&out.IsRetryable,
|
||||||
|
&out.RetryCount,
|
||||||
|
&out.Resolved,
|
||||||
|
&resolvedAt,
|
||||||
|
&resolvedBy,
|
||||||
|
&resolvedRetryID,
|
||||||
&out.ClientRequestID,
|
&out.ClientRequestID,
|
||||||
&out.RequestID,
|
&out.RequestID,
|
||||||
&out.Message,
|
&out.Message,
|
||||||
@@ -331,9 +397,12 @@ LIMIT 1`
|
|||||||
&out.UpstreamErrors,
|
&out.UpstreamErrors,
|
||||||
&out.IsBusinessLimited,
|
&out.IsBusinessLimited,
|
||||||
&userID,
|
&userID,
|
||||||
|
&out.UserEmail,
|
||||||
&apiKeyID,
|
&apiKeyID,
|
||||||
&accountID,
|
&accountID,
|
||||||
|
&out.AccountName,
|
||||||
&groupID,
|
&groupID,
|
||||||
|
&out.GroupName,
|
||||||
&clientIP,
|
&clientIP,
|
||||||
&out.RequestPath,
|
&out.RequestPath,
|
||||||
&out.Stream,
|
&out.Stream,
|
||||||
@@ -353,9 +422,17 @@ LIMIT 1`
|
|||||||
}
|
}
|
||||||
|
|
||||||
out.StatusCode = int(statusCode.Int64)
|
out.StatusCode = int(statusCode.Int64)
|
||||||
if latency.Valid {
|
if resolvedAt.Valid {
|
||||||
v := int(latency.Int64)
|
t := resolvedAt.Time
|
||||||
out.LatencyMs = &v
|
out.ResolvedAt = &t
|
||||||
|
}
|
||||||
|
if resolvedBy.Valid {
|
||||||
|
v := resolvedBy.Int64
|
||||||
|
out.ResolvedByUserID = &v
|
||||||
|
}
|
||||||
|
if resolvedRetryID.Valid {
|
||||||
|
v := resolvedRetryID.Int64
|
||||||
|
out.ResolvedRetryID = &v
|
||||||
}
|
}
|
||||||
if clientIP.Valid {
|
if clientIP.Valid {
|
||||||
s := clientIP.String
|
s := clientIP.String
|
||||||
@@ -485,9 +562,15 @@ SET
|
|||||||
status = $2,
|
status = $2,
|
||||||
finished_at = $3,
|
finished_at = $3,
|
||||||
duration_ms = $4,
|
duration_ms = $4,
|
||||||
result_request_id = $5,
|
success = $5,
|
||||||
result_error_id = $6,
|
http_status_code = $6,
|
||||||
error_message = $7
|
upstream_request_id = $7,
|
||||||
|
used_account_id = $8,
|
||||||
|
response_preview = $9,
|
||||||
|
response_truncated = $10,
|
||||||
|
result_request_id = $11,
|
||||||
|
result_error_id = $12,
|
||||||
|
error_message = $13
|
||||||
WHERE id = $1`
|
WHERE id = $1`
|
||||||
|
|
||||||
_, err := r.db.ExecContext(
|
_, err := r.db.ExecContext(
|
||||||
@@ -497,8 +580,14 @@ WHERE id = $1`
|
|||||||
strings.TrimSpace(input.Status),
|
strings.TrimSpace(input.Status),
|
||||||
nullTime(input.FinishedAt),
|
nullTime(input.FinishedAt),
|
||||||
input.DurationMs,
|
input.DurationMs,
|
||||||
|
nullBool(input.Success),
|
||||||
|
nullInt(input.HTTPStatusCode),
|
||||||
|
opsNullString(input.UpstreamRequestID),
|
||||||
|
nullInt64(input.UsedAccountID),
|
||||||
|
opsNullString(input.ResponsePreview),
|
||||||
|
nullBool(input.ResponseTruncated),
|
||||||
opsNullString(input.ResultRequestID),
|
opsNullString(input.ResultRequestID),
|
||||||
opsNullInt64(input.ResultErrorID),
|
nullInt64(input.ResultErrorID),
|
||||||
opsNullString(input.ErrorMessage),
|
opsNullString(input.ErrorMessage),
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
@@ -524,6 +613,12 @@ SELECT
|
|||||||
started_at,
|
started_at,
|
||||||
finished_at,
|
finished_at,
|
||||||
duration_ms,
|
duration_ms,
|
||||||
|
success,
|
||||||
|
http_status_code,
|
||||||
|
upstream_request_id,
|
||||||
|
used_account_id,
|
||||||
|
response_preview,
|
||||||
|
response_truncated,
|
||||||
result_request_id,
|
result_request_id,
|
||||||
result_error_id,
|
result_error_id,
|
||||||
error_message
|
error_message
|
||||||
@@ -538,6 +633,12 @@ LIMIT 1`
|
|||||||
var startedAt sql.NullTime
|
var startedAt sql.NullTime
|
||||||
var finishedAt sql.NullTime
|
var finishedAt sql.NullTime
|
||||||
var durationMs sql.NullInt64
|
var durationMs sql.NullInt64
|
||||||
|
var success sql.NullBool
|
||||||
|
var httpStatusCode sql.NullInt64
|
||||||
|
var upstreamRequestID sql.NullString
|
||||||
|
var usedAccountID sql.NullInt64
|
||||||
|
var responsePreview sql.NullString
|
||||||
|
var responseTruncated sql.NullBool
|
||||||
var resultRequestID sql.NullString
|
var resultRequestID sql.NullString
|
||||||
var resultErrorID sql.NullInt64
|
var resultErrorID sql.NullInt64
|
||||||
var errorMessage sql.NullString
|
var errorMessage sql.NullString
|
||||||
@@ -553,6 +654,12 @@ LIMIT 1`
|
|||||||
&startedAt,
|
&startedAt,
|
||||||
&finishedAt,
|
&finishedAt,
|
||||||
&durationMs,
|
&durationMs,
|
||||||
|
&success,
|
||||||
|
&httpStatusCode,
|
||||||
|
&upstreamRequestID,
|
||||||
|
&usedAccountID,
|
||||||
|
&responsePreview,
|
||||||
|
&responseTruncated,
|
||||||
&resultRequestID,
|
&resultRequestID,
|
||||||
&resultErrorID,
|
&resultErrorID,
|
||||||
&errorMessage,
|
&errorMessage,
|
||||||
@@ -577,6 +684,30 @@ LIMIT 1`
|
|||||||
v := durationMs.Int64
|
v := durationMs.Int64
|
||||||
out.DurationMs = &v
|
out.DurationMs = &v
|
||||||
}
|
}
|
||||||
|
if success.Valid {
|
||||||
|
v := success.Bool
|
||||||
|
out.Success = &v
|
||||||
|
}
|
||||||
|
if httpStatusCode.Valid {
|
||||||
|
v := int(httpStatusCode.Int64)
|
||||||
|
out.HTTPStatusCode = &v
|
||||||
|
}
|
||||||
|
if upstreamRequestID.Valid {
|
||||||
|
s := upstreamRequestID.String
|
||||||
|
out.UpstreamRequestID = &s
|
||||||
|
}
|
||||||
|
if usedAccountID.Valid {
|
||||||
|
v := usedAccountID.Int64
|
||||||
|
out.UsedAccountID = &v
|
||||||
|
}
|
||||||
|
if responsePreview.Valid {
|
||||||
|
s := responsePreview.String
|
||||||
|
out.ResponsePreview = &s
|
||||||
|
}
|
||||||
|
if responseTruncated.Valid {
|
||||||
|
v := responseTruncated.Bool
|
||||||
|
out.ResponseTruncated = &v
|
||||||
|
}
|
||||||
if resultRequestID.Valid {
|
if resultRequestID.Valid {
|
||||||
s := resultRequestID.String
|
s := resultRequestID.String
|
||||||
out.ResultRequestID = &s
|
out.ResultRequestID = &s
|
||||||
@@ -600,30 +731,234 @@ func nullTime(t time.Time) sql.NullTime {
|
|||||||
return sql.NullTime{Time: t, Valid: true}
|
return sql.NullTime{Time: t, Valid: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func nullBool(v *bool) sql.NullBool {
|
||||||
|
if v == nil {
|
||||||
|
return sql.NullBool{}
|
||||||
|
}
|
||||||
|
return sql.NullBool{Bool: *v, Valid: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*service.OpsRetryAttempt, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if sourceErrorID <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid source_error_id")
|
||||||
|
}
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 50
|
||||||
|
}
|
||||||
|
if limit > 200 {
|
||||||
|
limit = 200
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
SELECT
|
||||||
|
r.id,
|
||||||
|
r.created_at,
|
||||||
|
COALESCE(r.requested_by_user_id, 0),
|
||||||
|
r.source_error_id,
|
||||||
|
COALESCE(r.mode, ''),
|
||||||
|
r.pinned_account_id,
|
||||||
|
COALESCE(pa.name, ''),
|
||||||
|
COALESCE(r.status, ''),
|
||||||
|
r.started_at,
|
||||||
|
r.finished_at,
|
||||||
|
r.duration_ms,
|
||||||
|
r.success,
|
||||||
|
r.http_status_code,
|
||||||
|
r.upstream_request_id,
|
||||||
|
r.used_account_id,
|
||||||
|
COALESCE(ua.name, ''),
|
||||||
|
r.response_preview,
|
||||||
|
r.response_truncated,
|
||||||
|
r.result_request_id,
|
||||||
|
r.result_error_id,
|
||||||
|
r.error_message
|
||||||
|
FROM ops_retry_attempts r
|
||||||
|
LEFT JOIN accounts pa ON r.pinned_account_id = pa.id
|
||||||
|
LEFT JOIN accounts ua ON r.used_account_id = ua.id
|
||||||
|
WHERE r.source_error_id = $1
|
||||||
|
ORDER BY r.created_at DESC
|
||||||
|
LIMIT $2`
|
||||||
|
|
||||||
|
rows, err := r.db.QueryContext(ctx, q, sourceErrorID, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
out := make([]*service.OpsRetryAttempt, 0, 16)
|
||||||
|
for rows.Next() {
|
||||||
|
var item service.OpsRetryAttempt
|
||||||
|
var pinnedAccountID sql.NullInt64
|
||||||
|
var pinnedAccountName string
|
||||||
|
var requestedBy sql.NullInt64
|
||||||
|
var startedAt sql.NullTime
|
||||||
|
var finishedAt sql.NullTime
|
||||||
|
var durationMs sql.NullInt64
|
||||||
|
var success sql.NullBool
|
||||||
|
var httpStatusCode sql.NullInt64
|
||||||
|
var upstreamRequestID sql.NullString
|
||||||
|
var usedAccountID sql.NullInt64
|
||||||
|
var usedAccountName string
|
||||||
|
var responsePreview sql.NullString
|
||||||
|
var responseTruncated sql.NullBool
|
||||||
|
var resultRequestID sql.NullString
|
||||||
|
var resultErrorID sql.NullInt64
|
||||||
|
var errorMessage sql.NullString
|
||||||
|
|
||||||
|
if err := rows.Scan(
|
||||||
|
&item.ID,
|
||||||
|
&item.CreatedAt,
|
||||||
|
&requestedBy,
|
||||||
|
&item.SourceErrorID,
|
||||||
|
&item.Mode,
|
||||||
|
&pinnedAccountID,
|
||||||
|
&pinnedAccountName,
|
||||||
|
&item.Status,
|
||||||
|
&startedAt,
|
||||||
|
&finishedAt,
|
||||||
|
&durationMs,
|
||||||
|
&success,
|
||||||
|
&httpStatusCode,
|
||||||
|
&upstreamRequestID,
|
||||||
|
&usedAccountID,
|
||||||
|
&usedAccountName,
|
||||||
|
&responsePreview,
|
||||||
|
&responseTruncated,
|
||||||
|
&resultRequestID,
|
||||||
|
&resultErrorID,
|
||||||
|
&errorMessage,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
item.RequestedByUserID = requestedBy.Int64
|
||||||
|
if pinnedAccountID.Valid {
|
||||||
|
v := pinnedAccountID.Int64
|
||||||
|
item.PinnedAccountID = &v
|
||||||
|
}
|
||||||
|
item.PinnedAccountName = pinnedAccountName
|
||||||
|
if startedAt.Valid {
|
||||||
|
t := startedAt.Time
|
||||||
|
item.StartedAt = &t
|
||||||
|
}
|
||||||
|
if finishedAt.Valid {
|
||||||
|
t := finishedAt.Time
|
||||||
|
item.FinishedAt = &t
|
||||||
|
}
|
||||||
|
if durationMs.Valid {
|
||||||
|
v := durationMs.Int64
|
||||||
|
item.DurationMs = &v
|
||||||
|
}
|
||||||
|
if success.Valid {
|
||||||
|
v := success.Bool
|
||||||
|
item.Success = &v
|
||||||
|
}
|
||||||
|
if httpStatusCode.Valid {
|
||||||
|
v := int(httpStatusCode.Int64)
|
||||||
|
item.HTTPStatusCode = &v
|
||||||
|
}
|
||||||
|
if upstreamRequestID.Valid {
|
||||||
|
item.UpstreamRequestID = &upstreamRequestID.String
|
||||||
|
}
|
||||||
|
if usedAccountID.Valid {
|
||||||
|
v := usedAccountID.Int64
|
||||||
|
item.UsedAccountID = &v
|
||||||
|
}
|
||||||
|
item.UsedAccountName = usedAccountName
|
||||||
|
if responsePreview.Valid {
|
||||||
|
item.ResponsePreview = &responsePreview.String
|
||||||
|
}
|
||||||
|
if responseTruncated.Valid {
|
||||||
|
v := responseTruncated.Bool
|
||||||
|
item.ResponseTruncated = &v
|
||||||
|
}
|
||||||
|
if resultRequestID.Valid {
|
||||||
|
item.ResultRequestID = &resultRequestID.String
|
||||||
|
}
|
||||||
|
if resultErrorID.Valid {
|
||||||
|
v := resultErrorID.Int64
|
||||||
|
item.ResultErrorID = &v
|
||||||
|
}
|
||||||
|
if errorMessage.Valid {
|
||||||
|
item.ErrorMessage = &errorMessage.String
|
||||||
|
}
|
||||||
|
out = append(out, &item)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if errorID <= 0 {
|
||||||
|
return fmt.Errorf("invalid error id")
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
UPDATE ops_error_logs
|
||||||
|
SET
|
||||||
|
resolved = $2,
|
||||||
|
resolved_at = $3,
|
||||||
|
resolved_by_user_id = $4,
|
||||||
|
resolved_retry_id = $5
|
||||||
|
WHERE id = $1`
|
||||||
|
|
||||||
|
at := sql.NullTime{}
|
||||||
|
if resolvedAt != nil && !resolvedAt.IsZero() {
|
||||||
|
at = sql.NullTime{Time: resolvedAt.UTC(), Valid: true}
|
||||||
|
} else if resolved {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
at = sql.NullTime{Time: now, Valid: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := r.db.ExecContext(
|
||||||
|
ctx,
|
||||||
|
q,
|
||||||
|
errorID,
|
||||||
|
resolved,
|
||||||
|
at,
|
||||||
|
nullInt64(resolvedByUserID),
|
||||||
|
nullInt64(resolvedRetryID),
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||||
clauses := make([]string, 0, 8)
|
clauses := make([]string, 0, 12)
|
||||||
args := make([]any, 0, 8)
|
args := make([]any, 0, 12)
|
||||||
clauses = append(clauses, "1=1")
|
clauses = append(clauses, "1=1")
|
||||||
|
|
||||||
phaseFilter := ""
|
phaseFilter := ""
|
||||||
if filter != nil {
|
if filter != nil {
|
||||||
phaseFilter = strings.TrimSpace(strings.ToLower(filter.Phase))
|
phaseFilter = strings.TrimSpace(strings.ToLower(filter.Phase))
|
||||||
}
|
}
|
||||||
// ops_error_logs primarily stores client-visible error requests (status>=400),
|
// ops_error_logs stores client-visible error requests (status>=400),
|
||||||
// but we also persist "recovered" upstream errors (status<400) for upstream health visibility.
|
// but we also persist "recovered" upstream errors (status<400) for upstream health visibility.
|
||||||
// By default, keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
|
// If Resolved is not specified, do not filter by resolved state (backward-compatible).
|
||||||
|
resolvedFilter := (*bool)(nil)
|
||||||
|
if filter != nil {
|
||||||
|
resolvedFilter = filter.Resolved
|
||||||
|
}
|
||||||
|
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
|
||||||
if phaseFilter != "upstream" {
|
if phaseFilter != "upstream" {
|
||||||
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
|
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
|
||||||
}
|
}
|
||||||
|
|
||||||
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||||
args = append(args, filter.StartTime.UTC())
|
args = append(args, filter.StartTime.UTC())
|
||||||
clauses = append(clauses, "created_at >= $"+itoa(len(args)))
|
clauses = append(clauses, "e.created_at >= $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
if filter.EndTime != nil && !filter.EndTime.IsZero() {
|
if filter.EndTime != nil && !filter.EndTime.IsZero() {
|
||||||
args = append(args, filter.EndTime.UTC())
|
args = append(args, filter.EndTime.UTC())
|
||||||
// Keep time-window semantics consistent with other ops queries: [start, end)
|
// Keep time-window semantics consistent with other ops queries: [start, end)
|
||||||
clauses = append(clauses, "created_at < $"+itoa(len(args)))
|
clauses = append(clauses, "e.created_at < $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
if p := strings.TrimSpace(filter.Platform); p != "" {
|
if p := strings.TrimSpace(filter.Platform); p != "" {
|
||||||
args = append(args, p)
|
args = append(args, p)
|
||||||
@@ -641,10 +976,59 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
|||||||
args = append(args, phase)
|
args = append(args, phase)
|
||||||
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
|
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
|
if filter != nil {
|
||||||
|
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
|
||||||
|
args = append(args, owner)
|
||||||
|
clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
|
||||||
|
args = append(args, source)
|
||||||
|
clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resolvedFilter != nil {
|
||||||
|
args = append(args, *resolvedFilter)
|
||||||
|
clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// View filter: errors vs excluded vs all.
|
||||||
|
// Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
|
||||||
|
view := ""
|
||||||
|
if filter != nil {
|
||||||
|
view = strings.ToLower(strings.TrimSpace(filter.View))
|
||||||
|
}
|
||||||
|
switch view {
|
||||||
|
case "", "errors":
|
||||||
|
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
||||||
|
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
|
||||||
|
case "excluded":
|
||||||
|
clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))")
|
||||||
|
case "all":
|
||||||
|
// no-op
|
||||||
|
default:
|
||||||
|
// treat unknown as default 'errors'
|
||||||
|
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
||||||
|
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
|
||||||
|
}
|
||||||
if len(filter.StatusCodes) > 0 {
|
if len(filter.StatusCodes) > 0 {
|
||||||
args = append(args, pq.Array(filter.StatusCodes))
|
args = append(args, pq.Array(filter.StatusCodes))
|
||||||
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
|
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
|
||||||
|
} else if filter.StatusCodesOther {
|
||||||
|
// "Other" means: status codes not in the common list.
|
||||||
|
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
|
||||||
|
args = append(args, pq.Array(known))
|
||||||
|
clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
|
||||||
}
|
}
|
||||||
|
// Exact correlation keys (preferred for request↔upstream linkage).
|
||||||
|
if rid := strings.TrimSpace(filter.RequestID); rid != "" {
|
||||||
|
args = append(args, rid)
|
||||||
|
clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
|
||||||
|
args = append(args, crid)
|
||||||
|
clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
|
||||||
if q := strings.TrimSpace(filter.Query); q != "" {
|
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||||
like := "%" + q + "%"
|
like := "%" + q + "%"
|
||||||
args = append(args, like)
|
args = append(args, like)
|
||||||
@@ -652,6 +1036,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
|||||||
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
|
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
|
||||||
|
like := "%" + userQuery + "%"
|
||||||
|
args = append(args, like)
|
||||||
|
n := itoa(len(args))
|
||||||
|
clauses = append(clauses, "u.email ILIKE $"+n)
|
||||||
|
}
|
||||||
|
|
||||||
return "WHERE " + strings.Join(clauses, " AND "), args
|
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -354,7 +354,7 @@ SELECT
|
|||||||
created_at
|
created_at
|
||||||
FROM ops_alert_events
|
FROM ops_alert_events
|
||||||
` + where + `
|
` + where + `
|
||||||
ORDER BY fired_at DESC
|
ORDER BY fired_at DESC, id DESC
|
||||||
LIMIT ` + limitArg
|
LIMIT ` + limitArg
|
||||||
|
|
||||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||||
@@ -413,6 +413,43 @@ LIMIT ` + limitArg
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) GetAlertEventByID(ctx context.Context, eventID int64) (*service.OpsAlertEvent, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if eventID <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid event id")
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
COALESCE(rule_id, 0),
|
||||||
|
COALESCE(severity, ''),
|
||||||
|
COALESCE(status, ''),
|
||||||
|
COALESCE(title, ''),
|
||||||
|
COALESCE(description, ''),
|
||||||
|
metric_value,
|
||||||
|
threshold_value,
|
||||||
|
dimensions,
|
||||||
|
fired_at,
|
||||||
|
resolved_at,
|
||||||
|
email_sent,
|
||||||
|
created_at
|
||||||
|
FROM ops_alert_events
|
||||||
|
WHERE id = $1`
|
||||||
|
|
||||||
|
row := r.db.QueryRowContext(ctx, q, eventID)
|
||||||
|
ev, err := scanOpsAlertEvent(row)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ev, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
|
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
|
||||||
if r == nil || r.db == nil {
|
if r == nil || r.db == nil {
|
||||||
return nil, fmt.Errorf("nil ops repository")
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
@@ -591,6 +628,121 @@ type opsAlertEventRow interface {
|
|||||||
Scan(dest ...any) error
|
Scan(dest ...any) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) CreateAlertSilence(ctx context.Context, input *service.OpsAlertSilence) (*service.OpsAlertSilence, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if input == nil {
|
||||||
|
return nil, fmt.Errorf("nil input")
|
||||||
|
}
|
||||||
|
if input.RuleID <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid rule_id")
|
||||||
|
}
|
||||||
|
platform := strings.TrimSpace(input.Platform)
|
||||||
|
if platform == "" {
|
||||||
|
return nil, fmt.Errorf("invalid platform")
|
||||||
|
}
|
||||||
|
if input.Until.IsZero() {
|
||||||
|
return nil, fmt.Errorf("invalid until")
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
INSERT INTO ops_alert_silences (
|
||||||
|
rule_id,
|
||||||
|
platform,
|
||||||
|
group_id,
|
||||||
|
region,
|
||||||
|
until,
|
||||||
|
reason,
|
||||||
|
created_by,
|
||||||
|
created_at
|
||||||
|
) VALUES (
|
||||||
|
$1,$2,$3,$4,$5,$6,$7,NOW()
|
||||||
|
)
|
||||||
|
RETURNING id, rule_id, platform, group_id, region, until, COALESCE(reason,''), created_by, created_at`
|
||||||
|
|
||||||
|
row := r.db.QueryRowContext(
|
||||||
|
ctx,
|
||||||
|
q,
|
||||||
|
input.RuleID,
|
||||||
|
platform,
|
||||||
|
opsNullInt64(input.GroupID),
|
||||||
|
opsNullString(input.Region),
|
||||||
|
input.Until,
|
||||||
|
opsNullString(input.Reason),
|
||||||
|
opsNullInt64(input.CreatedBy),
|
||||||
|
)
|
||||||
|
|
||||||
|
var out service.OpsAlertSilence
|
||||||
|
var groupID sql.NullInt64
|
||||||
|
var region sql.NullString
|
||||||
|
var createdBy sql.NullInt64
|
||||||
|
if err := row.Scan(
|
||||||
|
&out.ID,
|
||||||
|
&out.RuleID,
|
||||||
|
&out.Platform,
|
||||||
|
&groupID,
|
||||||
|
®ion,
|
||||||
|
&out.Until,
|
||||||
|
&out.Reason,
|
||||||
|
&createdBy,
|
||||||
|
&out.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if groupID.Valid {
|
||||||
|
v := groupID.Int64
|
||||||
|
out.GroupID = &v
|
||||||
|
}
|
||||||
|
if region.Valid {
|
||||||
|
v := strings.TrimSpace(region.String)
|
||||||
|
if v != "" {
|
||||||
|
out.Region = &v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if createdBy.Valid {
|
||||||
|
v := createdBy.Int64
|
||||||
|
out.CreatedBy = &v
|
||||||
|
}
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return false, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if ruleID <= 0 {
|
||||||
|
return false, fmt.Errorf("invalid rule id")
|
||||||
|
}
|
||||||
|
platform = strings.TrimSpace(platform)
|
||||||
|
if platform == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if now.IsZero() {
|
||||||
|
now = time.Now().UTC()
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
SELECT 1
|
||||||
|
FROM ops_alert_silences
|
||||||
|
WHERE rule_id = $1
|
||||||
|
AND platform = $2
|
||||||
|
AND (group_id IS NOT DISTINCT FROM $3)
|
||||||
|
AND (region IS NOT DISTINCT FROM $4)
|
||||||
|
AND until > $5
|
||||||
|
LIMIT 1`
|
||||||
|
|
||||||
|
var dummy int
|
||||||
|
err := r.db.QueryRowContext(ctx, q, ruleID, platform, opsNullInt64(groupID), opsNullString(region), now).Scan(&dummy)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
|
func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
|
||||||
var ev service.OpsAlertEvent
|
var ev service.OpsAlertEvent
|
||||||
var metricValue sql.NullFloat64
|
var metricValue sql.NullFloat64
|
||||||
@@ -652,6 +804,10 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
|
|||||||
args = append(args, severity)
|
args = append(args, severity)
|
||||||
clauses = append(clauses, "severity = $"+itoa(len(args)))
|
clauses = append(clauses, "severity = $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
|
if filter.EmailSent != nil {
|
||||||
|
args = append(args, *filter.EmailSent)
|
||||||
|
clauses = append(clauses, "email_sent = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||||
args = append(args, *filter.StartTime)
|
args = append(args, *filter.StartTime)
|
||||||
clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
|
clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
|
||||||
@@ -661,6 +817,14 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
|
|||||||
clauses = append(clauses, "fired_at < $"+itoa(len(args)))
|
clauses = append(clauses, "fired_at < $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cursor pagination (descending by fired_at, then id)
|
||||||
|
if filter.BeforeFiredAt != nil && !filter.BeforeFiredAt.IsZero() && filter.BeforeID != nil && *filter.BeforeID > 0 {
|
||||||
|
args = append(args, *filter.BeforeFiredAt)
|
||||||
|
tsArg := "$" + itoa(len(args))
|
||||||
|
args = append(args, *filter.BeforeID)
|
||||||
|
idArg := "$" + itoa(len(args))
|
||||||
|
clauses = append(clauses, fmt.Sprintf("(fired_at < %s OR (fired_at = %s AND id < %s))", tsArg, tsArg, idArg))
|
||||||
|
}
|
||||||
// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
|
// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
|
||||||
if platform := strings.TrimSpace(filter.Platform); platform != "" {
|
if platform := strings.TrimSpace(filter.Platform); platform != "" {
|
||||||
args = append(args, platform)
|
args = append(args, platform)
|
||||||
|
|||||||
@@ -964,8 +964,8 @@ func buildErrorWhere(filter *service.OpsDashboardFilter, start, end time.Time, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
idx := startIndex
|
idx := startIndex
|
||||||
clauses := make([]string, 0, 4)
|
clauses := make([]string, 0, 5)
|
||||||
args = make([]any, 0, 4)
|
args = make([]any, 0, 5)
|
||||||
|
|
||||||
args = append(args, start)
|
args = append(args, start)
|
||||||
clauses = append(clauses, fmt.Sprintf("created_at >= $%d", idx))
|
clauses = append(clauses, fmt.Sprintf("created_at >= $%d", idx))
|
||||||
@@ -974,6 +974,8 @@ func buildErrorWhere(filter *service.OpsDashboardFilter, start, end time.Time, s
|
|||||||
clauses = append(clauses, fmt.Sprintf("created_at < $%d", idx))
|
clauses = append(clauses, fmt.Sprintf("created_at < $%d", idx))
|
||||||
idx++
|
idx++
|
||||||
|
|
||||||
|
clauses = append(clauses, "is_count_tokens = FALSE")
|
||||||
|
|
||||||
if groupID != nil && *groupID > 0 {
|
if groupID != nil && *groupID > 0 {
|
||||||
args = append(args, *groupID)
|
args = append(args, *groupID)
|
||||||
clauses = append(clauses, fmt.Sprintf("group_id = $%d", idx))
|
clauses = append(clauses, fmt.Sprintf("group_id = $%d", idx))
|
||||||
|
|||||||
@@ -71,14 +71,18 @@ usage_agg AS (
|
|||||||
error_base AS (
|
error_base AS (
|
||||||
SELECT
|
SELECT
|
||||||
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
||||||
platform AS platform,
|
-- platform is NULL for some early-phase errors (e.g. before routing); map to a sentinel
|
||||||
|
-- value so platform-level GROUPING SETS don't collide with the overall (platform=NULL) row.
|
||||||
|
COALESCE(platform, 'unknown') AS platform,
|
||||||
group_id AS group_id,
|
group_id AS group_id,
|
||||||
is_business_limited AS is_business_limited,
|
is_business_limited AS is_business_limited,
|
||||||
error_owner AS error_owner,
|
error_owner AS error_owner,
|
||||||
status_code AS client_status_code,
|
status_code AS client_status_code,
|
||||||
COALESCE(upstream_status_code, status_code, 0) AS effective_status_code
|
COALESCE(upstream_status_code, status_code, 0) AS effective_status_code
|
||||||
FROM ops_error_logs
|
FROM ops_error_logs
|
||||||
|
-- Exclude count_tokens requests from error metrics as they are informational probes
|
||||||
WHERE created_at >= $1 AND created_at < $2
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
AND is_count_tokens = FALSE
|
||||||
),
|
),
|
||||||
error_agg AS (
|
error_agg AS (
|
||||||
SELECT
|
SELECT
|
||||||
|
|||||||
129
backend/internal/repository/ops_repo_realtime_traffic.go
Normal file
129
backend/internal/repository/ops_repo_realtime_traffic.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *opsRepository) GetRealtimeTrafficSummary(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsRealtimeTrafficSummary, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if filter == nil {
|
||||||
|
return nil, fmt.Errorf("nil filter")
|
||||||
|
}
|
||||||
|
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||||
|
return nil, fmt.Errorf("start_time/end_time required")
|
||||||
|
}
|
||||||
|
|
||||||
|
start := filter.StartTime.UTC()
|
||||||
|
end := filter.EndTime.UTC()
|
||||||
|
if start.After(end) {
|
||||||
|
return nil, fmt.Errorf("start_time must be <= end_time")
|
||||||
|
}
|
||||||
|
|
||||||
|
window := end.Sub(start)
|
||||||
|
if window <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid time window")
|
||||||
|
}
|
||||||
|
if window > time.Hour {
|
||||||
|
return nil, fmt.Errorf("window too large")
|
||||||
|
}
|
||||||
|
|
||||||
|
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
|
||||||
|
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
|
||||||
|
|
||||||
|
q := `
|
||||||
|
WITH usage_buckets AS (
|
||||||
|
SELECT
|
||||||
|
date_trunc('minute', ul.created_at) AS bucket,
|
||||||
|
COALESCE(COUNT(*), 0) AS success_count,
|
||||||
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_sum
|
||||||
|
FROM usage_logs ul
|
||||||
|
` + usageJoin + `
|
||||||
|
` + usageWhere + `
|
||||||
|
GROUP BY 1
|
||||||
|
),
|
||||||
|
error_buckets AS (
|
||||||
|
SELECT
|
||||||
|
date_trunc('minute', created_at) AS bucket,
|
||||||
|
COALESCE(COUNT(*), 0) AS error_count
|
||||||
|
FROM ops_error_logs
|
||||||
|
` + errorWhere + `
|
||||||
|
AND COALESCE(status_code, 0) >= 400
|
||||||
|
GROUP BY 1
|
||||||
|
),
|
||||||
|
combined AS (
|
||||||
|
SELECT
|
||||||
|
COALESCE(u.bucket, e.bucket) AS bucket,
|
||||||
|
COALESCE(u.success_count, 0) AS success_count,
|
||||||
|
COALESCE(u.token_sum, 0) AS token_sum,
|
||||||
|
COALESCE(e.error_count, 0) AS error_count,
|
||||||
|
COALESCE(u.success_count, 0) + COALESCE(e.error_count, 0) AS request_total
|
||||||
|
FROM usage_buckets u
|
||||||
|
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
COALESCE(SUM(success_count), 0) AS success_total,
|
||||||
|
COALESCE(SUM(error_count), 0) AS error_total,
|
||||||
|
COALESCE(SUM(token_sum), 0) AS token_total,
|
||||||
|
COALESCE(MAX(request_total), 0) AS peak_requests_per_min,
|
||||||
|
COALESCE(MAX(token_sum), 0) AS peak_tokens_per_min
|
||||||
|
FROM combined`
|
||||||
|
|
||||||
|
args := append(usageArgs, errorArgs...)
|
||||||
|
var successCount int64
|
||||||
|
var errorTotal int64
|
||||||
|
var tokenConsumed int64
|
||||||
|
var peakRequestsPerMin int64
|
||||||
|
var peakTokensPerMin int64
|
||||||
|
if err := r.db.QueryRowContext(ctx, q, args...).Scan(
|
||||||
|
&successCount,
|
||||||
|
&errorTotal,
|
||||||
|
&tokenConsumed,
|
||||||
|
&peakRequestsPerMin,
|
||||||
|
&peakTokensPerMin,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
windowSeconds := window.Seconds()
|
||||||
|
if windowSeconds <= 0 {
|
||||||
|
windowSeconds = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
requestCountTotal := successCount + errorTotal
|
||||||
|
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
|
||||||
|
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
|
||||||
|
|
||||||
|
// Keep "current" consistent with the dashboard overview semantics: last 1 minute.
|
||||||
|
// This remains "within the selected window" since end=start+window.
|
||||||
|
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
qpsPeak := roundTo1DP(float64(peakRequestsPerMin) / 60.0)
|
||||||
|
tpsPeak := roundTo1DP(float64(peakTokensPerMin) / 60.0)
|
||||||
|
|
||||||
|
return &service.OpsRealtimeTrafficSummary{
|
||||||
|
StartTime: start,
|
||||||
|
EndTime: end,
|
||||||
|
Platform: strings.TrimSpace(filter.Platform),
|
||||||
|
GroupID: filter.GroupID,
|
||||||
|
QPS: service.OpsRateSummary{
|
||||||
|
Current: qpsCurrent,
|
||||||
|
Peak: qpsPeak,
|
||||||
|
Avg: qpsAvg,
|
||||||
|
},
|
||||||
|
TPS: service.OpsRateSummary{
|
||||||
|
Current: tpsCurrent,
|
||||||
|
Peak: tpsPeak,
|
||||||
|
Avg: tpsAvg,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -170,6 +170,7 @@ error_totals AS (
|
|||||||
FROM ops_error_logs
|
FROM ops_error_logs
|
||||||
WHERE created_at >= $1 AND created_at < $2
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
AND COALESCE(status_code, 0) >= 400
|
AND COALESCE(status_code, 0) >= 400
|
||||||
|
AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
|
||||||
GROUP BY 1
|
GROUP BY 1
|
||||||
),
|
),
|
||||||
combined AS (
|
combined AS (
|
||||||
@@ -243,6 +244,7 @@ error_totals AS (
|
|||||||
AND platform = $3
|
AND platform = $3
|
||||||
AND group_id IS NOT NULL
|
AND group_id IS NOT NULL
|
||||||
AND COALESCE(status_code, 0) >= 400
|
AND COALESCE(status_code, 0) >= 400
|
||||||
|
AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
|
||||||
GROUP BY 1
|
GROUP BY 1
|
||||||
),
|
),
|
||||||
combined AS (
|
combined AS (
|
||||||
|
|||||||
74
backend/internal/repository/proxy_latency_cache.go
Normal file
74
backend/internal/repository/proxy_latency_cache.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const proxyLatencyKeyPrefix = "proxy:latency:"
|
||||||
|
|
||||||
|
func proxyLatencyKey(proxyID int64) string {
|
||||||
|
return fmt.Sprintf("%s%d", proxyLatencyKeyPrefix, proxyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
type proxyLatencyCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProxyLatencyCache(rdb *redis.Client) service.ProxyLatencyCache {
|
||||||
|
return &proxyLatencyCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyLatencyCache) GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*service.ProxyLatencyInfo, error) {
|
||||||
|
results := make(map[int64]*service.ProxyLatencyInfo)
|
||||||
|
if len(proxyIDs) == 0 {
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]string, 0, len(proxyIDs))
|
||||||
|
for _, id := range proxyIDs {
|
||||||
|
keys = append(keys, proxyLatencyKey(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
values, err := c.rdb.MGet(ctx, keys...).Result()
|
||||||
|
if err != nil {
|
||||||
|
return results, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, raw := range values {
|
||||||
|
if raw == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var payload []byte
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
payload = []byte(v)
|
||||||
|
case []byte:
|
||||||
|
payload = v
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var info service.ProxyLatencyInfo
|
||||||
|
if err := json.Unmarshal(payload, &info); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
results[proxyIDs[i]] = &info
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyLatencyCache) SetProxyLatency(ctx context.Context, proxyID int64, info *service.ProxyLatencyInfo) error {
|
||||||
|
if info == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(info)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.rdb.Set(ctx, proxyLatencyKey(proxyID), payload, 0).Err()
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
@@ -34,7 +35,10 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultIPInfoURL = "https://ipinfo.io/json"
|
const (
|
||||||
|
defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
|
||||||
|
defaultProxyProbeTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
type proxyProbeService struct {
|
type proxyProbeService struct {
|
||||||
ipInfoURL string
|
ipInfoURL string
|
||||||
@@ -46,7 +50,7 @@ type proxyProbeService struct {
|
|||||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||||
client, err := httpclient.GetClient(httpclient.Options{
|
client, err := httpclient.GetClient(httpclient.Options{
|
||||||
ProxyURL: proxyURL,
|
ProxyURL: proxyURL,
|
||||||
Timeout: 15 * time.Second,
|
Timeout: defaultProxyProbeTimeout,
|
||||||
InsecureSkipVerify: s.insecureSkipVerify,
|
InsecureSkipVerify: s.insecureSkipVerify,
|
||||||
ProxyStrict: true,
|
ProxyStrict: true,
|
||||||
ValidateResolvedIP: s.validateResolvedIP,
|
ValidateResolvedIP: s.validateResolvedIP,
|
||||||
@@ -75,10 +79,14 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
|||||||
}
|
}
|
||||||
|
|
||||||
var ipInfo struct {
|
var ipInfo struct {
|
||||||
IP string `json:"ip"`
|
Status string `json:"status"`
|
||||||
City string `json:"city"`
|
Message string `json:"message"`
|
||||||
Region string `json:"region"`
|
Query string `json:"query"`
|
||||||
Country string `json:"country"`
|
City string `json:"city"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
RegionName string `json:"regionName"`
|
||||||
|
Country string `json:"country"`
|
||||||
|
CountryCode string `json:"countryCode"`
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
@@ -89,11 +97,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
|||||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||||
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
|
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
|
||||||
}
|
}
|
||||||
|
if strings.ToLower(ipInfo.Status) != "success" {
|
||||||
|
if ipInfo.Message == "" {
|
||||||
|
ipInfo.Message = "ip-api request failed"
|
||||||
|
}
|
||||||
|
return nil, latencyMs, fmt.Errorf("ip-api request failed: %s", ipInfo.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
region := ipInfo.RegionName
|
||||||
|
if region == "" {
|
||||||
|
region = ipInfo.Region
|
||||||
|
}
|
||||||
return &service.ProxyExitInfo{
|
return &service.ProxyExitInfo{
|
||||||
IP: ipInfo.IP,
|
IP: ipInfo.Query,
|
||||||
City: ipInfo.City,
|
City: ipInfo.City,
|
||||||
Region: ipInfo.Region,
|
Region: region,
|
||||||
Country: ipInfo.Country,
|
Country: ipInfo.Country,
|
||||||
|
CountryCode: ipInfo.CountryCode,
|
||||||
}, latencyMs, nil
|
}, latencyMs, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ type ProxyProbeServiceSuite struct {
|
|||||||
func (s *ProxyProbeServiceSuite) SetupTest() {
|
func (s *ProxyProbeServiceSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
s.prober = &proxyProbeService{
|
s.prober = &proxyProbeService{
|
||||||
ipInfoURL: "http://ipinfo.test/json",
|
ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
|
||||||
allowPrivateHosts: true,
|
allowPrivateHosts: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -54,7 +54,7 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
|||||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
seen <- r.RequestURI
|
seen <- r.RequestURI
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
|
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||||
@@ -64,11 +64,12 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
|||||||
require.Equal(s.T(), "c", info.City)
|
require.Equal(s.T(), "c", info.City)
|
||||||
require.Equal(s.T(), "r", info.Region)
|
require.Equal(s.T(), "r", info.Region)
|
||||||
require.Equal(s.T(), "cc", info.Country)
|
require.Equal(s.T(), "cc", info.Country)
|
||||||
|
require.Equal(s.T(), "CC", info.CountryCode)
|
||||||
|
|
||||||
// Verify proxy received the request
|
// Verify proxy received the request
|
||||||
select {
|
select {
|
||||||
case uri := <-seen:
|
case uri := <-seen:
|
||||||
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
|
require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
|
||||||
default:
|
default:
|
||||||
require.Fail(s.T(), "expected proxy to receive request")
|
require.Fail(s.T(), "expected proxy to receive request")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -219,12 +219,54 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
|
|||||||
// CountAccountsByProxyID returns the number of accounts using a specific proxy
|
// CountAccountsByProxyID returns the number of accounts using a specific proxy
|
||||||
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||||
var count int64
|
var count int64
|
||||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil {
|
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1 AND deleted_at IS NULL", []any{proxyID}, &count); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *proxyRepository) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
|
||||||
|
rows, err := r.sql.QueryContext(ctx, `
|
||||||
|
SELECT id, name, platform, type, notes
|
||||||
|
FROM accounts
|
||||||
|
WHERE proxy_id = $1 AND deleted_at IS NULL
|
||||||
|
ORDER BY id DESC
|
||||||
|
`, proxyID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
out := make([]service.ProxyAccountSummary, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
id int64
|
||||||
|
name string
|
||||||
|
platform string
|
||||||
|
accType string
|
||||||
|
notes sql.NullString
|
||||||
|
)
|
||||||
|
if err := rows.Scan(&id, &name, &platform, &accType, ¬es); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var notesPtr *string
|
||||||
|
if notes.Valid {
|
||||||
|
notesPtr = ¬es.String
|
||||||
|
}
|
||||||
|
out = append(out, service.ProxyAccountSummary{
|
||||||
|
ID: id,
|
||||||
|
Name: name,
|
||||||
|
Platform: platform,
|
||||||
|
Type: accType,
|
||||||
|
Notes: notesPtr,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
|
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
|
||||||
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
|
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
|
||||||
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
|
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
|
||||||
|
|||||||
276
backend/internal/repository/scheduler_cache.go
Normal file
276
backend/internal/repository/scheduler_cache.go
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
schedulerBucketSetKey = "sched:buckets"
|
||||||
|
schedulerOutboxWatermarkKey = "sched:outbox:watermark"
|
||||||
|
schedulerAccountPrefix = "sched:acc:"
|
||||||
|
schedulerActivePrefix = "sched:active:"
|
||||||
|
schedulerReadyPrefix = "sched:ready:"
|
||||||
|
schedulerVersionPrefix = "sched:ver:"
|
||||||
|
schedulerSnapshotPrefix = "sched:"
|
||||||
|
schedulerLockPrefix = "sched:lock:"
|
||||||
|
)
|
||||||
|
|
||||||
|
type schedulerCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache {
|
||||||
|
return &schedulerCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||||
|
readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
|
||||||
|
readyVal, err := c.rdb.Get(ctx, readyKey).Result()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
if readyVal != "1" {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
|
||||||
|
activeVal, err := c.rdb.Get(ctx, activeKey).Result()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
snapshotKey := schedulerSnapshotKey(bucket, activeVal)
|
||||||
|
ids, err := c.rdb.ZRange(ctx, snapshotKey, 0, -1).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return []*service.Account{}, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]string, 0, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
keys = append(keys, schedulerAccountKey(id))
|
||||||
|
}
|
||||||
|
values, err := c.rdb.MGet(ctx, keys...).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
accounts := make([]*service.Account, 0, len(values))
|
||||||
|
for _, val := range values {
|
||||||
|
if val == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
account, err := decodeCachedAccount(val)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
accounts = append(accounts, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
return accounts, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
|
||||||
|
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
|
||||||
|
oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
|
||||||
|
|
||||||
|
versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
|
||||||
|
version, err := c.rdb.Incr(ctx, versionKey).Result()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
versionStr := strconv.FormatInt(version, 10)
|
||||||
|
snapshotKey := schedulerSnapshotKey(bucket, versionStr)
|
||||||
|
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
for _, account := range accounts {
|
||||||
|
payload, err := json.Marshal(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pipe.Set(ctx, schedulerAccountKey(strconv.FormatInt(account.ID, 10)), payload, 0)
|
||||||
|
}
|
||||||
|
if len(accounts) > 0 {
|
||||||
|
// 使用序号作为 score,保持数据库返回的排序语义。
|
||||||
|
members := make([]redis.Z, 0, len(accounts))
|
||||||
|
for idx, account := range accounts {
|
||||||
|
members = append(members, redis.Z{
|
||||||
|
Score: float64(idx),
|
||||||
|
Member: strconv.FormatInt(account.ID, 10),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
pipe.ZAdd(ctx, snapshotKey, members...)
|
||||||
|
} else {
|
||||||
|
pipe.Del(ctx, snapshotKey)
|
||||||
|
}
|
||||||
|
pipe.Set(ctx, activeKey, versionStr, 0)
|
||||||
|
pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
|
||||||
|
pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
|
||||||
|
if _, err := pipe.Exec(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldActive != "" && oldActive != versionStr {
|
||||||
|
_ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||||
|
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
|
||||||
|
val, err := c.rdb.Get(ctx, key).Result()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return decodeCachedAccount(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Account) error {
|
||||||
|
if account == nil || account.ID <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
key := schedulerAccountKey(strconv.FormatInt(account.ID, 10))
|
||||||
|
return c.rdb.Set(ctx, key, payload, 0).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error {
|
||||||
|
if accountID <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]string, 0, len(updates))
|
||||||
|
ids := make([]int64, 0, len(updates))
|
||||||
|
for id := range updates {
|
||||||
|
keys = append(keys, schedulerAccountKey(strconv.FormatInt(id, 10)))
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
values, err := c.rdb.MGet(ctx, keys...).Result()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
for i, val := range values {
|
||||||
|
if val == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
account, err := decodeCachedAccount(val)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
account.LastUsedAt = ptrTime(updates[ids[i]])
|
||||||
|
updated, err := json.Marshal(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pipe.Set(ctx, keys[i], updated, 0)
|
||||||
|
}
|
||||||
|
_, err = pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
|
||||||
|
key := schedulerBucketKey(schedulerLockPrefix, bucket)
|
||||||
|
return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||||
|
raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out := make([]service.SchedulerBucket, 0, len(raw))
|
||||||
|
for _, entry := range raw {
|
||||||
|
bucket, ok := service.ParseSchedulerBucket(entry)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, bucket)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) GetOutboxWatermark(ctx context.Context) (int64, error) {
|
||||||
|
val, err := c.rdb.Get(ctx, schedulerOutboxWatermarkKey).Result()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
id, err := strconv.ParseInt(val, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) SetOutboxWatermark(ctx context.Context, id int64) error {
|
||||||
|
return c.rdb.Set(ctx, schedulerOutboxWatermarkKey, strconv.FormatInt(id, 10), 0).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func schedulerBucketKey(prefix string, bucket service.SchedulerBucket) string {
|
||||||
|
return fmt.Sprintf("%s%d:%s:%s", prefix, bucket.GroupID, bucket.Platform, bucket.Mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func schedulerSnapshotKey(bucket service.SchedulerBucket, version string) string {
|
||||||
|
return fmt.Sprintf("%s%d:%s:%s:v%s", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
func schedulerAccountKey(id string) string {
|
||||||
|
return schedulerAccountPrefix + id
|
||||||
|
}
|
||||||
|
|
||||||
|
func ptrTime(t time.Time) *time.Time {
|
||||||
|
return &t
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeCachedAccount(val any) (*service.Account, error) {
|
||||||
|
var payload []byte
|
||||||
|
switch raw := val.(type) {
|
||||||
|
case string:
|
||||||
|
payload = []byte(raw)
|
||||||
|
case []byte:
|
||||||
|
payload = raw
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected account cache type: %T", val)
|
||||||
|
}
|
||||||
|
var account service.Account
|
||||||
|
if err := json.Unmarshal(payload, &account); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &account, nil
|
||||||
|
}
|
||||||
96
backend/internal/repository/scheduler_outbox_repo.go
Normal file
96
backend/internal/repository/scheduler_outbox_repo.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type schedulerOutboxRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
|
||||||
|
return &schedulerOutboxRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *schedulerOutboxRepository) ListAfter(ctx context.Context, afterID int64, limit int) ([]service.SchedulerOutboxEvent, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 100
|
||||||
|
}
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, event_type, account_id, group_id, payload, created_at
|
||||||
|
FROM scheduler_outbox
|
||||||
|
WHERE id > $1
|
||||||
|
ORDER BY id ASC
|
||||||
|
LIMIT $2
|
||||||
|
`, afterID, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = rows.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
events := make([]service.SchedulerOutboxEvent, 0, limit)
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
payloadRaw []byte
|
||||||
|
accountID sql.NullInt64
|
||||||
|
groupID sql.NullInt64
|
||||||
|
event service.SchedulerOutboxEvent
|
||||||
|
)
|
||||||
|
if err := rows.Scan(&event.ID, &event.EventType, &accountID, &groupID, &payloadRaw, &event.CreatedAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if accountID.Valid {
|
||||||
|
v := accountID.Int64
|
||||||
|
event.AccountID = &v
|
||||||
|
}
|
||||||
|
if groupID.Valid {
|
||||||
|
v := groupID.Int64
|
||||||
|
event.GroupID = &v
|
||||||
|
}
|
||||||
|
if len(payloadRaw) > 0 {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(payloadRaw, &payload); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
event.Payload = payload
|
||||||
|
}
|
||||||
|
events = append(events, event)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *schedulerOutboxRepository) MaxID(ctx context.Context) (int64, error) {
|
||||||
|
var maxID int64
|
||||||
|
if err := r.db.QueryRowContext(ctx, "SELECT COALESCE(MAX(id), 0) FROM scheduler_outbox").Scan(&maxID); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return maxID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType string, accountID *int64, groupID *int64, payload any) error {
|
||||||
|
if exec == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var payloadArg any
|
||||||
|
if payload != nil {
|
||||||
|
encoded, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payloadArg = encoded
|
||||||
|
}
|
||||||
|
_, err := exec.ExecContext(ctx, `
|
||||||
|
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
`, eventType, accountID, groupID, payloadArg)
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := testRedis(t)
|
||||||
|
client := testEntClient(t)
|
||||||
|
|
||||||
|
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
|
||||||
|
|
||||||
|
accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
|
||||||
|
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
|
||||||
|
cache := NewSchedulerCache(rdb)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
RunMode: config.RunModeStandard,
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
Scheduling: config.GatewaySchedulingConfig{
|
||||||
|
OutboxPollIntervalSeconds: 1,
|
||||||
|
FullRebuildIntervalSeconds: 0,
|
||||||
|
DbFallbackEnabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &service.Account{
|
||||||
|
Name: "outbox-replay-" + time.Now().Format("150405.000000"),
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeAPIKey,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 3,
|
||||||
|
Priority: 1,
|
||||||
|
Credentials: map[string]any{},
|
||||||
|
Extra: map[string]any{},
|
||||||
|
}
|
||||||
|
require.NoError(t, accountRepo.Create(ctx, account))
|
||||||
|
require.NoError(t, cache.SetAccount(ctx, account))
|
||||||
|
|
||||||
|
svc := service.NewSchedulerSnapshotService(cache, outboxRepo, accountRepo, nil, cfg)
|
||||||
|
svc.Start()
|
||||||
|
t.Cleanup(svc.Stop)
|
||||||
|
|
||||||
|
require.NoError(t, accountRepo.UpdateLastUsed(ctx, account.ID))
|
||||||
|
updated, err := accountRepo.GetByID(ctx, account.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updated.LastUsedAt)
|
||||||
|
expectedUnix := updated.LastUsedAt.Unix()
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
cached, err := cache.GetAccount(ctx, account.ID)
|
||||||
|
if err != nil || cached == nil || cached.LastUsedAt == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return cached.LastUsedAt.Unix() == expectedUnix
|
||||||
|
}, 5*time.Second, 100*time.Millisecond)
|
||||||
|
}
|
||||||
80
backend/internal/repository/timeout_counter_cache.go
Normal file
80
backend/internal/repository/timeout_counter_cache.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const timeoutCounterPrefix = "timeout_count:account:"
|
||||||
|
|
||||||
|
// timeoutCounterIncrScript 使用 Lua 脚本原子性地增加计数并返回当前值
|
||||||
|
// 如果 key 不存在,则创建并设置过期时间
|
||||||
|
var timeoutCounterIncrScript = redis.NewScript(`
|
||||||
|
local key = KEYS[1]
|
||||||
|
local ttl = tonumber(ARGV[1])
|
||||||
|
|
||||||
|
local count = redis.call('INCR', key)
|
||||||
|
if count == 1 then
|
||||||
|
redis.call('EXPIRE', key, ttl)
|
||||||
|
end
|
||||||
|
|
||||||
|
return count
|
||||||
|
`)
|
||||||
|
|
||||||
|
type timeoutCounterCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTimeoutCounterCache 创建超时计数器缓存实例
|
||||||
|
func NewTimeoutCounterCache(rdb *redis.Client) service.TimeoutCounterCache {
|
||||||
|
return &timeoutCounterCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementTimeoutCount 增加账户的超时计数,返回当前计数值
|
||||||
|
// windowMinutes 是计数窗口时间(分钟),超过此时间计数器会自动重置
|
||||||
|
func (c *timeoutCounterCache) IncrementTimeoutCount(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||||
|
|
||||||
|
ttlSeconds := windowMinutes * 60
|
||||||
|
if ttlSeconds < 60 {
|
||||||
|
ttlSeconds = 60 // 最小1分钟
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := timeoutCounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("increment timeout count: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTimeoutCount 获取账户当前的超时计数
|
||||||
|
func (c *timeoutCounterCache) GetTimeoutCount(ctx context.Context, accountID int64) (int64, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||||
|
|
||||||
|
val, err := c.rdb.Get(ctx, key).Int64()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("get timeout count: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetTimeoutCount 重置账户的超时计数
|
||||||
|
func (c *timeoutCounterCache) ResetTimeoutCount(ctx context.Context, accountID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTimeoutCountTTL 获取计数器剩余过期时间
|
||||||
|
func (c *timeoutCounterCache) GetTimeoutCountTTL(ctx context.Context, accountID int64) (time.Duration, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||||
|
return c.rdb.TTL(ctx, key).Result()
|
||||||
|
}
|
||||||
@@ -22,7 +22,7 @@ import (
|
|||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
|
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
|
||||||
|
|
||||||
type usageLogRepository struct {
|
type usageLogRepository struct {
|
||||||
client *dbent.Client
|
client *dbent.Client
|
||||||
@@ -105,6 +105,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
total_cost,
|
total_cost,
|
||||||
actual_cost,
|
actual_cost,
|
||||||
rate_multiplier,
|
rate_multiplier,
|
||||||
|
account_rate_multiplier,
|
||||||
billing_type,
|
billing_type,
|
||||||
stream,
|
stream,
|
||||||
duration_ms,
|
duration_ms,
|
||||||
@@ -120,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
$8, $9, $10, $11,
|
$8, $9, $10, $11,
|
||||||
$12, $13,
|
$12, $13,
|
||||||
$14, $15, $16, $17, $18, $19,
|
$14, $15, $16, $17, $18, $19,
|
||||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
|
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
RETURNING id, created_at
|
RETURNING id, created_at
|
||||||
@@ -160,6 +161,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
log.TotalCost,
|
log.TotalCost,
|
||||||
log.ActualCost,
|
log.ActualCost,
|
||||||
rateMultiplier,
|
rateMultiplier,
|
||||||
|
log.AccountRateMultiplier,
|
||||||
log.BillingType,
|
log.BillingType,
|
||||||
log.Stream,
|
log.Stream,
|
||||||
duration,
|
duration,
|
||||||
@@ -270,13 +272,13 @@ type DashboardStats = usagestats.DashboardStats
|
|||||||
|
|
||||||
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||||
stats := &DashboardStats{}
|
stats := &DashboardStats{}
|
||||||
now := time.Now().UTC()
|
now := timezone.Now()
|
||||||
todayUTC := truncateToDayUTC(now)
|
todayStart := timezone.Today()
|
||||||
|
|
||||||
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
|
if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayUTC, now); err != nil {
|
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayStart, now); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,13 +300,13 @@ func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, sta
|
|||||||
}
|
}
|
||||||
|
|
||||||
stats := &DashboardStats{}
|
stats := &DashboardStats{}
|
||||||
now := time.Now().UTC()
|
now := timezone.Now()
|
||||||
todayUTC := truncateToDayUTC(now)
|
todayStart := timezone.Today()
|
||||||
|
|
||||||
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
|
if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayUTC, now); err != nil {
|
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayStart, now); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -455,7 +457,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
|
|||||||
FROM usage_dashboard_hourly
|
FROM usage_dashboard_hourly
|
||||||
WHERE bucket_start = $1
|
WHERE bucket_start = $1
|
||||||
`
|
`
|
||||||
hourStart := now.UTC().Truncate(time.Hour)
|
hourStart := now.In(timezone.Location()).Truncate(time.Hour)
|
||||||
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
|
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
@@ -835,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
|||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(actual_cost), 0) as cost
|
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||||
|
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE account_id = $1 AND created_at >= $2
|
WHERE account_id = $1 AND created_at >= $2
|
||||||
`
|
`
|
||||||
@@ -849,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
|||||||
&stats.Requests,
|
&stats.Requests,
|
||||||
&stats.Tokens,
|
&stats.Tokens,
|
||||||
&stats.Cost,
|
&stats.Cost,
|
||||||
|
&stats.StandardCost,
|
||||||
|
&stats.UserCost,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -861,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
|||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(actual_cost), 0) as cost
|
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||||
|
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE account_id = $1 AND created_at >= $2
|
WHERE account_id = $1 AND created_at >= $2
|
||||||
`
|
`
|
||||||
@@ -875,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
|||||||
&stats.Requests,
|
&stats.Requests,
|
||||||
&stats.Tokens,
|
&stats.Tokens,
|
||||||
&stats.Cost,
|
&stats.Cost,
|
||||||
|
&stats.StandardCost,
|
||||||
|
&stats.UserCost,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1400,8 +1410,8 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
|
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
||||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
|
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
|
||||||
dateFormat := "YYYY-MM-DD"
|
dateFormat := "YYYY-MM-DD"
|
||||||
if granularity == "hour" {
|
if granularity == "hour" {
|
||||||
dateFormat = "YYYY-MM-DD HH24:00"
|
dateFormat = "YYYY-MM-DD HH24:00"
|
||||||
@@ -1430,6 +1440,22 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
|||||||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||||||
args = append(args, apiKeyID)
|
args = append(args, apiKeyID)
|
||||||
}
|
}
|
||||||
|
if accountID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||||
|
args = append(args, accountID)
|
||||||
|
}
|
||||||
|
if groupID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||||
|
args = append(args, groupID)
|
||||||
|
}
|
||||||
|
if model != "" {
|
||||||
|
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||||
|
args = append(args, model)
|
||||||
|
}
|
||||||
|
if stream != nil {
|
||||||
|
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||||||
|
args = append(args, *stream)
|
||||||
|
}
|
||||||
query += " GROUP BY date ORDER BY date ASC"
|
query += " GROUP BY date ORDER BY date ASC"
|
||||||
|
|
||||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||||
@@ -1452,9 +1478,15 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
|||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
// GetModelStatsWithFilters returns model statistics with optional filters
|
||||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
|
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
|
||||||
query := `
|
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||||
|
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
||||||
|
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||||
|
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||||
|
}
|
||||||
|
|
||||||
|
query := fmt.Sprintf(`
|
||||||
SELECT
|
SELECT
|
||||||
model,
|
model,
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
@@ -1462,10 +1494,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
|||||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
%s
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE created_at >= $1 AND created_at < $2
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
`
|
`, actualCostExpr)
|
||||||
|
|
||||||
args := []any{startTime, endTime}
|
args := []any{startTime, endTime}
|
||||||
if userID > 0 {
|
if userID > 0 {
|
||||||
@@ -1480,6 +1512,14 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
|||||||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||||
args = append(args, accountID)
|
args = append(args, accountID)
|
||||||
}
|
}
|
||||||
|
if groupID > 0 {
|
||||||
|
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||||
|
args = append(args, groupID)
|
||||||
|
}
|
||||||
|
if stream != nil {
|
||||||
|
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||||||
|
args = append(args, *stream)
|
||||||
|
}
|
||||||
query += " GROUP BY model ORDER BY total_tokens DESC"
|
query += " GROUP BY model ORDER BY total_tokens DESC"
|
||||||
|
|
||||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||||
@@ -1587,12 +1627,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
|||||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||||
|
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
|
||||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
%s
|
%s
|
||||||
`, buildWhere(conditions))
|
`, buildWhere(conditions))
|
||||||
|
|
||||||
stats := &UsageStats{}
|
stats := &UsageStats{}
|
||||||
|
var totalAccountCost float64
|
||||||
if err := scanSingleRow(
|
if err := scanSingleRow(
|
||||||
ctx,
|
ctx,
|
||||||
r.sql,
|
r.sql,
|
||||||
@@ -1604,10 +1646,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
|||||||
&stats.TotalCacheTokens,
|
&stats.TotalCacheTokens,
|
||||||
&stats.TotalCost,
|
&stats.TotalCost,
|
||||||
&stats.TotalActualCost,
|
&stats.TotalActualCost,
|
||||||
|
&totalAccountCost,
|
||||||
&stats.AverageDurationMs,
|
&stats.AverageDurationMs,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if filters.AccountID > 0 {
|
||||||
|
stats.TotalAccountCost = &totalAccountCost
|
||||||
|
}
|
||||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
@@ -1634,7 +1680,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||||||
GROUP BY date
|
GROUP BY date
|
||||||
@@ -1661,7 +1708,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
var tokens int64
|
var tokens int64
|
||||||
var cost float64
|
var cost float64
|
||||||
var actualCost float64
|
var actualCost float64
|
||||||
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
|
var userCost float64
|
||||||
|
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost, &userCost); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
t, _ := time.Parse("2006-01-02", date)
|
t, _ := time.Parse("2006-01-02", date)
|
||||||
@@ -1672,19 +1720,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
Tokens: tokens,
|
Tokens: tokens,
|
||||||
Cost: cost,
|
Cost: cost,
|
||||||
ActualCost: actualCost,
|
ActualCost: actualCost,
|
||||||
|
UserCost: userCost,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if err = rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var totalActualCost, totalStandardCost float64
|
var totalAccountCost, totalUserCost, totalStandardCost float64
|
||||||
var totalRequests, totalTokens int64
|
var totalRequests, totalTokens int64
|
||||||
var highestCostDay, highestRequestDay *AccountUsageHistory
|
var highestCostDay, highestRequestDay *AccountUsageHistory
|
||||||
|
|
||||||
for i := range history {
|
for i := range history {
|
||||||
h := &history[i]
|
h := &history[i]
|
||||||
totalActualCost += h.ActualCost
|
totalAccountCost += h.ActualCost
|
||||||
|
totalUserCost += h.UserCost
|
||||||
totalStandardCost += h.Cost
|
totalStandardCost += h.Cost
|
||||||
totalRequests += h.Requests
|
totalRequests += h.Requests
|
||||||
totalTokens += h.Tokens
|
totalTokens += h.Tokens
|
||||||
@@ -1711,11 +1761,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
summary := AccountUsageSummary{
|
summary := AccountUsageSummary{
|
||||||
Days: daysCount,
|
Days: daysCount,
|
||||||
ActualDaysUsed: actualDaysUsed,
|
ActualDaysUsed: actualDaysUsed,
|
||||||
TotalCost: totalActualCost,
|
TotalCost: totalAccountCost,
|
||||||
|
TotalUserCost: totalUserCost,
|
||||||
TotalStandardCost: totalStandardCost,
|
TotalStandardCost: totalStandardCost,
|
||||||
TotalRequests: totalRequests,
|
TotalRequests: totalRequests,
|
||||||
TotalTokens: totalTokens,
|
TotalTokens: totalTokens,
|
||||||
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
|
AvgDailyCost: totalAccountCost / float64(actualDaysUsed),
|
||||||
|
AvgDailyUserCost: totalUserCost / float64(actualDaysUsed),
|
||||||
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
|
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
|
||||||
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
|
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
|
||||||
AvgDurationMs: avgDuration,
|
AvgDurationMs: avgDuration,
|
||||||
@@ -1727,11 +1779,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
summary.Today = &struct {
|
summary.Today = &struct {
|
||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Tokens int64 `json:"tokens"`
|
Tokens int64 `json:"tokens"`
|
||||||
}{
|
}{
|
||||||
Date: history[i].Date,
|
Date: history[i].Date,
|
||||||
Cost: history[i].ActualCost,
|
Cost: history[i].ActualCost,
|
||||||
|
UserCost: history[i].UserCost,
|
||||||
Requests: history[i].Requests,
|
Requests: history[i].Requests,
|
||||||
Tokens: history[i].Tokens,
|
Tokens: history[i].Tokens,
|
||||||
}
|
}
|
||||||
@@ -1744,11 +1798,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
Label string `json:"label"`
|
Label string `json:"label"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
}{
|
}{
|
||||||
Date: highestCostDay.Date,
|
Date: highestCostDay.Date,
|
||||||
Label: highestCostDay.Label,
|
Label: highestCostDay.Label,
|
||||||
Cost: highestCostDay.ActualCost,
|
Cost: highestCostDay.ActualCost,
|
||||||
|
UserCost: highestCostDay.UserCost,
|
||||||
Requests: highestCostDay.Requests,
|
Requests: highestCostDay.Requests,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1759,15 +1815,17 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
Label string `json:"label"`
|
Label string `json:"label"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
}{
|
}{
|
||||||
Date: highestRequestDay.Date,
|
Date: highestRequestDay.Date,
|
||||||
Label: highestRequestDay.Label,
|
Label: highestRequestDay.Label,
|
||||||
Requests: highestRequestDay.Requests,
|
Requests: highestRequestDay.Requests,
|
||||||
Cost: highestRequestDay.ActualCost,
|
Cost: highestRequestDay.ActualCost,
|
||||||
|
UserCost: highestRequestDay.UserCost,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
|
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
models = []ModelStat{}
|
models = []ModelStat{}
|
||||||
}
|
}
|
||||||
@@ -1994,36 +2052,37 @@ func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64)
|
|||||||
|
|
||||||
func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) {
|
func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) {
|
||||||
var (
|
var (
|
||||||
id int64
|
id int64
|
||||||
userID int64
|
userID int64
|
||||||
apiKeyID int64
|
apiKeyID int64
|
||||||
accountID int64
|
accountID int64
|
||||||
requestID sql.NullString
|
requestID sql.NullString
|
||||||
model string
|
model string
|
||||||
groupID sql.NullInt64
|
groupID sql.NullInt64
|
||||||
subscriptionID sql.NullInt64
|
subscriptionID sql.NullInt64
|
||||||
inputTokens int
|
inputTokens int
|
||||||
outputTokens int
|
outputTokens int
|
||||||
cacheCreationTokens int
|
cacheCreationTokens int
|
||||||
cacheReadTokens int
|
cacheReadTokens int
|
||||||
cacheCreation5m int
|
cacheCreation5m int
|
||||||
cacheCreation1h int
|
cacheCreation1h int
|
||||||
inputCost float64
|
inputCost float64
|
||||||
outputCost float64
|
outputCost float64
|
||||||
cacheCreationCost float64
|
cacheCreationCost float64
|
||||||
cacheReadCost float64
|
cacheReadCost float64
|
||||||
totalCost float64
|
totalCost float64
|
||||||
actualCost float64
|
actualCost float64
|
||||||
rateMultiplier float64
|
rateMultiplier float64
|
||||||
billingType int16
|
accountRateMultiplier sql.NullFloat64
|
||||||
stream bool
|
billingType int16
|
||||||
durationMs sql.NullInt64
|
stream bool
|
||||||
firstTokenMs sql.NullInt64
|
durationMs sql.NullInt64
|
||||||
userAgent sql.NullString
|
firstTokenMs sql.NullInt64
|
||||||
ipAddress sql.NullString
|
userAgent sql.NullString
|
||||||
imageCount int
|
ipAddress sql.NullString
|
||||||
imageSize sql.NullString
|
imageCount int
|
||||||
createdAt time.Time
|
imageSize sql.NullString
|
||||||
|
createdAt time.Time
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := scanner.Scan(
|
if err := scanner.Scan(
|
||||||
@@ -2048,6 +2107,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
&totalCost,
|
&totalCost,
|
||||||
&actualCost,
|
&actualCost,
|
||||||
&rateMultiplier,
|
&rateMultiplier,
|
||||||
|
&accountRateMultiplier,
|
||||||
&billingType,
|
&billingType,
|
||||||
&stream,
|
&stream,
|
||||||
&durationMs,
|
&durationMs,
|
||||||
@@ -2080,6 +2140,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
TotalCost: totalCost,
|
TotalCost: totalCost,
|
||||||
ActualCost: actualCost,
|
ActualCost: actualCost,
|
||||||
RateMultiplier: rateMultiplier,
|
RateMultiplier: rateMultiplier,
|
||||||
|
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
|
||||||
BillingType: int8(billingType),
|
BillingType: int8(billingType),
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
ImageCount: imageCount,
|
ImageCount: imageCount,
|
||||||
@@ -2186,6 +2247,14 @@ func nullInt(v *int) sql.NullInt64 {
|
|||||||
return sql.NullInt64{Int64: int64(*v), Valid: true}
|
return sql.NullInt64{Int64: int64(*v), Valid: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func nullFloat64Ptr(v sql.NullFloat64) *float64 {
|
||||||
|
if !v.Valid {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := v.Float64
|
||||||
|
return &out
|
||||||
|
}
|
||||||
|
|
||||||
func nullString(v *string) sql.NullString {
|
func nullString(v *string) sql.NullString {
|
||||||
if v == nil || *v == "" {
|
if v == nil || *v == "" {
|
||||||
return sql.NullString{}
|
return sql.NullString{}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
@@ -36,6 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
|
|||||||
suite.Run(t, new(UsageLogRepoSuite))
|
suite.Run(t, new(UsageLogRepoSuite))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// truncateToDayUTC 截断到 UTC 日期边界(测试辅助函数)
|
||||||
|
func truncateToDayUTC(t time.Time) time.Time {
|
||||||
|
t = t.UTC()
|
||||||
|
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
||||||
log := &service.UsageLog{
|
log := &service.UsageLog{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
@@ -95,6 +102,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
|
|||||||
s.Require().Error(err, "expected error for non-existent ID")
|
s.Require().Error(err, "expected error for non-existent ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
|
||||||
|
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-mult@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-mult", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-mult"})
|
||||||
|
|
||||||
|
m := 0.5
|
||||||
|
log := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.New().String(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 1.0,
|
||||||
|
ActualCost: 2.0,
|
||||||
|
AccountRateMultiplier: &m,
|
||||||
|
CreatedAt: timezone.Today().Add(2 * time.Hour),
|
||||||
|
}
|
||||||
|
_, err := s.repo.Create(s.ctx, log)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, log.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(got.AccountRateMultiplier)
|
||||||
|
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
|
||||||
|
}
|
||||||
|
|
||||||
// --- Delete ---
|
// --- Delete ---
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestDelete() {
|
func (s *UsageLogRepoSuite) TestDelete() {
|
||||||
@@ -403,12 +438,49 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
|||||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
|
||||||
|
|
||||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
createdAt := timezone.Today().Add(1 * time.Hour)
|
||||||
|
|
||||||
|
m1 := 1.5
|
||||||
|
m2 := 0.0
|
||||||
|
_, err := s.repo.Create(s.ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.New().String(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 1.0,
|
||||||
|
ActualCost: 2.0,
|
||||||
|
AccountRateMultiplier: &m1,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
_, err = s.repo.Create(s.ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.New().String(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 5,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 1.0,
|
||||||
|
AccountRateMultiplier: &m2,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
|
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
|
||||||
s.Require().NoError(err, "GetAccountTodayStats")
|
s.Require().NoError(err, "GetAccountTodayStats")
|
||||||
s.Require().Equal(int64(1), stats.Requests)
|
s.Require().Equal(int64(2), stats.Requests)
|
||||||
s.Require().Equal(int64(30), stats.Tokens)
|
s.Require().Equal(int64(40), stats.Tokens)
|
||||||
|
// account cost = SUM(total_cost * account_rate_multiplier)
|
||||||
|
s.Require().InEpsilon(1.5, stats.Cost, 0.0001)
|
||||||
|
// standard cost = SUM(total_cost)
|
||||||
|
s.Require().InEpsilon(1.5, stats.StandardCost, 0.0001)
|
||||||
|
// user cost = SUM(actual_cost)
|
||||||
|
s.Require().InEpsilon(3.0, stats.UserCost, 0.0001)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
|
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
|
||||||
@@ -416,8 +488,8 @@ func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
|
|||||||
// 使用固定的时间偏移确保 hour1 和 hour2 在同一天且都在过去
|
// 使用固定的时间偏移确保 hour1 和 hour2 在同一天且都在过去
|
||||||
// 选择当天 02:00 和 03:00 作为测试时间点(基于 now 的日期)
|
// 选择当天 02:00 和 03:00 作为测试时间点(基于 now 的日期)
|
||||||
dayStart := truncateToDayUTC(now)
|
dayStart := truncateToDayUTC(now)
|
||||||
hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00
|
hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00
|
||||||
hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00
|
hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00
|
||||||
// 如果当前时间早于 hour2,则使用昨天的时间
|
// 如果当前时间早于 hour2,则使用昨天的时间
|
||||||
if now.Before(hour2.Add(time.Hour)) {
|
if now.Before(hour2.Add(time.Hour)) {
|
||||||
dayStart = dayStart.Add(-24 * time.Hour)
|
dayStart = dayStart.Add(-24 * time.Hour)
|
||||||
@@ -872,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
|||||||
endTime := base.Add(48 * time.Hour)
|
endTime := base.Add(48 * time.Hour)
|
||||||
|
|
||||||
// Test with user filter
|
// Test with user filter
|
||||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0)
|
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil)
|
||||||
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
|
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
|
||||||
s.Require().Len(trend, 2)
|
s.Require().Len(trend, 2)
|
||||||
|
|
||||||
// Test with apiKey filter
|
// Test with apiKey filter
|
||||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID)
|
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil)
|
||||||
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
|
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
|
||||||
s.Require().Len(trend, 2)
|
s.Require().Len(trend, 2)
|
||||||
|
|
||||||
// Test with both filters
|
// Test with both filters
|
||||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID)
|
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil)
|
||||||
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
|
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
|
||||||
s.Require().Len(trend, 2)
|
s.Require().Len(trend, 2)
|
||||||
}
|
}
|
||||||
@@ -899,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
|||||||
startTime := base.Add(-1 * time.Hour)
|
startTime := base.Add(-1 * time.Hour)
|
||||||
endTime := base.Add(3 * time.Hour)
|
endTime := base.Add(3 * time.Hour)
|
||||||
|
|
||||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0)
|
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil)
|
||||||
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
|
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
|
||||||
s.Require().Len(trend, 2)
|
s.Require().Len(trend, 2)
|
||||||
}
|
}
|
||||||
@@ -945,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
|||||||
endTime := base.Add(2 * time.Hour)
|
endTime := base.Add(2 * time.Hour)
|
||||||
|
|
||||||
// Test with user filter
|
// Test with user filter
|
||||||
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0)
|
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil)
|
||||||
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
|
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
|
||||||
s.Require().Len(stats, 2)
|
s.Require().Len(stats, 2)
|
||||||
|
|
||||||
// Test with apiKey filter
|
// Test with apiKey filter
|
||||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0)
|
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil)
|
||||||
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
|
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
|
||||||
s.Require().Len(stats, 2)
|
s.Require().Len(stats, 2)
|
||||||
|
|
||||||
// Test with account filter
|
// Test with account filter
|
||||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID)
|
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil)
|
||||||
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
|
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
|
||||||
s.Require().Len(stats, 2)
|
s.Require().Len(stats, 2)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewBillingCache,
|
NewBillingCache,
|
||||||
NewAPIKeyCache,
|
NewAPIKeyCache,
|
||||||
NewTempUnschedCache,
|
NewTempUnschedCache,
|
||||||
|
NewTimeoutCounterCache,
|
||||||
ProvideConcurrencyCache,
|
ProvideConcurrencyCache,
|
||||||
NewDashboardCache,
|
NewDashboardCache,
|
||||||
NewEmailCache,
|
NewEmailCache,
|
||||||
@@ -66,6 +67,9 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewRedeemCache,
|
NewRedeemCache,
|
||||||
NewUpdateCache,
|
NewUpdateCache,
|
||||||
NewGeminiTokenCache,
|
NewGeminiTokenCache,
|
||||||
|
NewSchedulerCache,
|
||||||
|
NewSchedulerOutboxRepository,
|
||||||
|
NewProxyLatencyCache,
|
||||||
|
|
||||||
// HTTP service ports (DI Strategy A: return interface directly)
|
// HTTP service ports (DI Strategy A: return interface directly)
|
||||||
NewTurnstileVerifier,
|
NewTurnstileVerifier,
|
||||||
|
|||||||
@@ -239,9 +239,10 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"cache_creation_cost": 0,
|
"cache_creation_cost": 0,
|
||||||
"cache_read_cost": 0,
|
"cache_read_cost": 0,
|
||||||
"total_cost": 0.5,
|
"total_cost": 0.5,
|
||||||
"actual_cost": 0.5,
|
"actual_cost": 0.5,
|
||||||
"rate_multiplier": 1,
|
"rate_multiplier": 1,
|
||||||
"billing_type": 0,
|
"account_rate_multiplier": null,
|
||||||
|
"billing_type": 0,
|
||||||
"stream": true,
|
"stream": true,
|
||||||
"duration_ms": 100,
|
"duration_ms": 100,
|
||||||
"first_token_ms": 50,
|
"first_token_ms": 50,
|
||||||
@@ -262,11 +263,11 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
name: "GET /api/v1/admin/settings",
|
name: "GET /api/v1/admin/settings",
|
||||||
setup: func(t *testing.T, deps *contractDeps) {
|
setup: func(t *testing.T, deps *contractDeps) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
deps.settingRepo.SetAll(map[string]string{
|
deps.settingRepo.SetAll(map[string]string{
|
||||||
service.SettingKeyRegistrationEnabled: "true",
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
service.SettingKeyEmailVerifyEnabled: "false",
|
service.SettingKeyEmailVerifyEnabled: "false",
|
||||||
|
|
||||||
service.SettingKeySMTPHost: "smtp.example.com",
|
service.SettingKeySMTPHost: "smtp.example.com",
|
||||||
service.SettingKeySMTPPort: "587",
|
service.SettingKeySMTPPort: "587",
|
||||||
service.SettingKeySMTPUsername: "user",
|
service.SettingKeySMTPUsername: "user",
|
||||||
service.SettingKeySMTPPassword: "secret",
|
service.SettingKeySMTPPassword: "secret",
|
||||||
@@ -285,15 +286,15 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
service.SettingKeyContactInfo: "support",
|
service.SettingKeyContactInfo: "support",
|
||||||
service.SettingKeyDocURL: "https://docs.example.com",
|
service.SettingKeyDocURL: "https://docs.example.com",
|
||||||
|
|
||||||
service.SettingKeyDefaultConcurrency: "5",
|
service.SettingKeyDefaultConcurrency: "5",
|
||||||
service.SettingKeyDefaultBalance: "1.25",
|
service.SettingKeyDefaultBalance: "1.25",
|
||||||
|
|
||||||
service.SettingKeyOpsMonitoringEnabled: "false",
|
service.SettingKeyOpsMonitoringEnabled: "false",
|
||||||
service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
|
service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
|
||||||
service.SettingKeyOpsQueryModeDefault: "auto",
|
service.SettingKeyOpsQueryModeDefault: "auto",
|
||||||
service.SettingKeyOpsMetricsIntervalSeconds: "60",
|
service.SettingKeyOpsMetricsIntervalSeconds: "60",
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: "/api/v1/admin/settings",
|
path: "/api/v1/admin/settings",
|
||||||
wantStatus: http.StatusOK,
|
wantStatus: http.StatusOK,
|
||||||
@@ -435,7 +436,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
settingRepo := newStubSettingRepo()
|
settingRepo := newStubSettingRepo()
|
||||||
settingService := service.NewSettingService(settingRepo, cfg)
|
settingService := service.NewSettingService(settingRepo, cfg)
|
||||||
|
|
||||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil)
|
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
|
||||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
@@ -858,6 +859,10 @@ func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64)
|
|||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubRedeemCodeRepo struct{}
|
type stubRedeemCodeRepo struct{}
|
||||||
|
|
||||||
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
|
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
|
||||||
@@ -1229,11 +1234,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
|
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) {
|
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
// Realtime ops signals
|
// Realtime ops signals
|
||||||
ops.GET("/concurrency", h.Admin.Ops.GetConcurrencyStats)
|
ops.GET("/concurrency", h.Admin.Ops.GetConcurrencyStats)
|
||||||
ops.GET("/account-availability", h.Admin.Ops.GetAccountAvailability)
|
ops.GET("/account-availability", h.Admin.Ops.GetAccountAvailability)
|
||||||
|
ops.GET("/realtime-traffic", h.Admin.Ops.GetRealtimeTrafficSummary)
|
||||||
|
|
||||||
// Alerts (rules + events)
|
// Alerts (rules + events)
|
||||||
ops.GET("/alert-rules", h.Admin.Ops.ListAlertRules)
|
ops.GET("/alert-rules", h.Admin.Ops.ListAlertRules)
|
||||||
@@ -80,6 +81,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
ops.PUT("/alert-rules/:id", h.Admin.Ops.UpdateAlertRule)
|
ops.PUT("/alert-rules/:id", h.Admin.Ops.UpdateAlertRule)
|
||||||
ops.DELETE("/alert-rules/:id", h.Admin.Ops.DeleteAlertRule)
|
ops.DELETE("/alert-rules/:id", h.Admin.Ops.DeleteAlertRule)
|
||||||
ops.GET("/alert-events", h.Admin.Ops.ListAlertEvents)
|
ops.GET("/alert-events", h.Admin.Ops.ListAlertEvents)
|
||||||
|
ops.GET("/alert-events/:id", h.Admin.Ops.GetAlertEvent)
|
||||||
|
ops.PUT("/alert-events/:id/status", h.Admin.Ops.UpdateAlertEventStatus)
|
||||||
|
ops.POST("/alert-silences", h.Admin.Ops.CreateAlertSilence)
|
||||||
|
|
||||||
// Email notification config (DB-backed)
|
// Email notification config (DB-backed)
|
||||||
ops.GET("/email-notification/config", h.Admin.Ops.GetEmailNotificationConfig)
|
ops.GET("/email-notification/config", h.Admin.Ops.GetEmailNotificationConfig)
|
||||||
@@ -96,16 +100,39 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
ops.GET("/advanced-settings", h.Admin.Ops.GetAdvancedSettings)
|
ops.GET("/advanced-settings", h.Admin.Ops.GetAdvancedSettings)
|
||||||
ops.PUT("/advanced-settings", h.Admin.Ops.UpdateAdvancedSettings)
|
ops.PUT("/advanced-settings", h.Admin.Ops.UpdateAdvancedSettings)
|
||||||
|
|
||||||
|
// Settings group (DB-backed)
|
||||||
|
settings := ops.Group("/settings")
|
||||||
|
{
|
||||||
|
settings.GET("/metric-thresholds", h.Admin.Ops.GetMetricThresholds)
|
||||||
|
settings.PUT("/metric-thresholds", h.Admin.Ops.UpdateMetricThresholds)
|
||||||
|
}
|
||||||
|
|
||||||
// WebSocket realtime (QPS/TPS)
|
// WebSocket realtime (QPS/TPS)
|
||||||
ws := ops.Group("/ws")
|
ws := ops.Group("/ws")
|
||||||
{
|
{
|
||||||
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
|
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error logs (MVP-1)
|
// Error logs (legacy)
|
||||||
ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
|
ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
|
||||||
ops.GET("/errors/:id", h.Admin.Ops.GetErrorLogByID)
|
ops.GET("/errors/:id", h.Admin.Ops.GetErrorLogByID)
|
||||||
|
ops.GET("/errors/:id/retries", h.Admin.Ops.ListRetryAttempts)
|
||||||
ops.POST("/errors/:id/retry", h.Admin.Ops.RetryErrorRequest)
|
ops.POST("/errors/:id/retry", h.Admin.Ops.RetryErrorRequest)
|
||||||
|
ops.PUT("/errors/:id/resolve", h.Admin.Ops.UpdateErrorResolution)
|
||||||
|
|
||||||
|
// Request errors (client-visible failures)
|
||||||
|
ops.GET("/request-errors", h.Admin.Ops.ListRequestErrors)
|
||||||
|
ops.GET("/request-errors/:id", h.Admin.Ops.GetRequestError)
|
||||||
|
ops.GET("/request-errors/:id/upstream-errors", h.Admin.Ops.ListRequestErrorUpstreamErrors)
|
||||||
|
ops.POST("/request-errors/:id/retry-client", h.Admin.Ops.RetryRequestErrorClient)
|
||||||
|
ops.POST("/request-errors/:id/upstream-errors/:idx/retry", h.Admin.Ops.RetryRequestErrorUpstreamEvent)
|
||||||
|
ops.PUT("/request-errors/:id/resolve", h.Admin.Ops.ResolveRequestError)
|
||||||
|
|
||||||
|
// Upstream errors (independent upstream failures)
|
||||||
|
ops.GET("/upstream-errors", h.Admin.Ops.ListUpstreamErrors)
|
||||||
|
ops.GET("/upstream-errors/:id", h.Admin.Ops.GetUpstreamError)
|
||||||
|
ops.POST("/upstream-errors/:id/retry", h.Admin.Ops.RetryUpstreamError)
|
||||||
|
ops.PUT("/upstream-errors/:id/resolve", h.Admin.Ops.ResolveUpstreamError)
|
||||||
|
|
||||||
// Request drilldown (success + error)
|
// Request drilldown (success + error)
|
||||||
ops.GET("/requests", h.Admin.Ops.ListRequestDetails)
|
ops.GET("/requests", h.Admin.Ops.ListRequestDetails)
|
||||||
@@ -242,6 +269,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
||||||
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
||||||
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
||||||
|
proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)
|
||||||
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
|
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -283,6 +311,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
|
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
|
||||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
|
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
|
||||||
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
|
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
|
||||||
|
// 流超时处理配置
|
||||||
|
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
||||||
|
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,16 +9,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Account struct {
|
type Account struct {
|
||||||
ID int64
|
ID int64
|
||||||
Name string
|
Name string
|
||||||
Notes *string
|
Notes *string
|
||||||
Platform string
|
Platform string
|
||||||
Type string
|
Type string
|
||||||
Credentials map[string]any
|
Credentials map[string]any
|
||||||
Extra map[string]any
|
Extra map[string]any
|
||||||
ProxyID *int64
|
ProxyID *int64
|
||||||
Concurrency int
|
Concurrency int
|
||||||
Priority int
|
Priority int
|
||||||
|
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
|
||||||
|
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
|
||||||
|
RateMultiplier *float64
|
||||||
Status string
|
Status string
|
||||||
ErrorMessage string
|
ErrorMessage string
|
||||||
LastUsedAt *time.Time
|
LastUsedAt *time.Time
|
||||||
@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool {
|
|||||||
return a.Status == StatusActive
|
return a.Status == StatusActive
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BillingRateMultiplier 返回账号计费倍率。
|
||||||
|
// - nil 表示未配置/旧缓存缺字段,按 1.0 处理
|
||||||
|
// - 允许 0,表示该账号计费为 0
|
||||||
|
// - 负数属于非法数据,出于安全考虑按 1.0 处理
|
||||||
|
func (a *Account) BillingRateMultiplier() float64 {
|
||||||
|
if a == nil || a.RateMultiplier == nil {
|
||||||
|
return 1.0
|
||||||
|
}
|
||||||
|
if *a.RateMultiplier < 0 {
|
||||||
|
return 1.0
|
||||||
|
}
|
||||||
|
return *a.RateMultiplier
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) IsSchedulable() bool {
|
func (a *Account) IsSchedulable() bool {
|
||||||
if !a.IsActive() || !a.Schedulable {
|
if !a.IsActive() || !a.Schedulable {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -0,0 +1,27 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
|
||||||
|
var a Account
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(`{"id":1,"name":"acc","status":"active"}`), &a))
|
||||||
|
require.Nil(t, a.RateMultiplier)
|
||||||
|
require.Equal(t, 1.0, a.BillingRateMultiplier())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
|
||||||
|
v := 0.0
|
||||||
|
a := Account{RateMultiplier: &v}
|
||||||
|
require.Equal(t, 0.0, a.BillingRateMultiplier())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
|
||||||
|
v := -1.0
|
||||||
|
a := Account{RateMultiplier: &v}
|
||||||
|
require.Equal(t, 1.0, a.BillingRateMultiplier())
|
||||||
|
}
|
||||||
@@ -63,14 +63,15 @@ type AccountRepository interface {
|
|||||||
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
|
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
|
||||||
// Nil pointers mean "do not change".
|
// Nil pointers mean "do not change".
|
||||||
type AccountBulkUpdate struct {
|
type AccountBulkUpdate struct {
|
||||||
Name *string
|
Name *string
|
||||||
ProxyID *int64
|
ProxyID *int64
|
||||||
Concurrency *int
|
Concurrency *int
|
||||||
Priority *int
|
Priority *int
|
||||||
Status *string
|
RateMultiplier *float64
|
||||||
Schedulable *bool
|
Status *string
|
||||||
Credentials map[string]any
|
Schedulable *bool
|
||||||
Extra map[string]any
|
Credentials map[string]any
|
||||||
|
Extra map[string]any
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccountRequest 创建账号请求
|
// CreateAccountRequest 创建账号请求
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
|
|||||||
|
|
||||||
// Admin dashboard stats
|
// Admin dashboard stats
|
||||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
|
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
|
||||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
|
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
|
||||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||||
@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WindowStats 窗口期统计
|
// WindowStats 窗口期统计
|
||||||
|
//
|
||||||
|
// cost: 账号口径费用(total_cost * account_rate_multiplier)
|
||||||
|
// standard_cost: 标准费用(total_cost,不含倍率)
|
||||||
|
// user_cost: 用户/API Key 口径费用(actual_cost,受分组倍率影响)
|
||||||
type WindowStats struct {
|
type WindowStats struct {
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Tokens int64 `json:"tokens"`
|
Tokens int64 `json:"tokens"`
|
||||||
Cost float64 `json:"cost"`
|
Cost float64 `json:"cost"`
|
||||||
|
StandardCost float64 `json:"standard_cost"`
|
||||||
|
UserCost float64 `json:"user_cost"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsageProgress 使用量进度
|
// UsageProgress 使用量进度
|
||||||
@@ -266,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
|
|
||||||
dayStart := geminiDailyWindowStart(now)
|
dayStart := geminiDailyWindowStart(now)
|
||||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID)
|
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -288,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
|||||||
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
|
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
|
||||||
minuteStart := now.Truncate(time.Minute)
|
minuteStart := now.Truncate(time.Minute)
|
||||||
minuteResetAt := minuteStart.Add(time.Minute)
|
minuteResetAt := minuteStart.Add(time.Minute)
|
||||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID)
|
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
|
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -377,9 +383,11 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
|
|
||||||
windowStats = &WindowStats{
|
windowStats = &WindowStats{
|
||||||
Requests: stats.Requests,
|
Requests: stats.Requests,
|
||||||
Tokens: stats.Tokens,
|
Tokens: stats.Tokens,
|
||||||
Cost: stats.Cost,
|
Cost: stats.Cost,
|
||||||
|
StandardCost: stats.StandardCost,
|
||||||
|
UserCost: stats.UserCost,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 缓存窗口统计(1 分钟)
|
// 缓存窗口统计(1 分钟)
|
||||||
@@ -403,9 +411,11 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &WindowStats{
|
return &WindowStats{
|
||||||
Requests: stats.Requests,
|
Requests: stats.Requests,
|
||||||
Tokens: stats.Tokens,
|
Tokens: stats.Tokens,
|
||||||
Cost: stats.Cost,
|
Cost: stats.Cost,
|
||||||
|
StandardCost: stats.StandardCost,
|
||||||
|
UserCost: stats.UserCost,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -54,7 +54,8 @@ type AdminService interface {
|
|||||||
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
|
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
|
||||||
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
|
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
|
||||||
DeleteProxy(ctx context.Context, id int64) error
|
DeleteProxy(ctx context.Context, id int64) error
|
||||||
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
|
BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error)
|
||||||
|
GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
|
||||||
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
|
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||||
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
|
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
|
||||||
|
|
||||||
@@ -136,6 +137,7 @@ type CreateAccountInput struct {
|
|||||||
ProxyID *int64
|
ProxyID *int64
|
||||||
Concurrency int
|
Concurrency int
|
||||||
Priority int
|
Priority int
|
||||||
|
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||||
GroupIDs []int64
|
GroupIDs []int64
|
||||||
ExpiresAt *int64
|
ExpiresAt *int64
|
||||||
AutoPauseOnExpired *bool
|
AutoPauseOnExpired *bool
|
||||||
@@ -151,8 +153,9 @@ type UpdateAccountInput struct {
|
|||||||
Credentials map[string]any
|
Credentials map[string]any
|
||||||
Extra map[string]any
|
Extra map[string]any
|
||||||
ProxyID *int64
|
ProxyID *int64
|
||||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||||
|
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||||
Status string
|
Status string
|
||||||
GroupIDs *[]int64
|
GroupIDs *[]int64
|
||||||
ExpiresAt *int64
|
ExpiresAt *int64
|
||||||
@@ -162,16 +165,17 @@ type UpdateAccountInput struct {
|
|||||||
|
|
||||||
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
|
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
|
||||||
type BulkUpdateAccountsInput struct {
|
type BulkUpdateAccountsInput struct {
|
||||||
AccountIDs []int64
|
AccountIDs []int64
|
||||||
Name string
|
Name string
|
||||||
ProxyID *int64
|
ProxyID *int64
|
||||||
Concurrency *int
|
Concurrency *int
|
||||||
Priority *int
|
Priority *int
|
||||||
Status string
|
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||||
Schedulable *bool
|
Status string
|
||||||
GroupIDs *[]int64
|
Schedulable *bool
|
||||||
Credentials map[string]any
|
GroupIDs *[]int64
|
||||||
Extra map[string]any
|
Credentials map[string]any
|
||||||
|
Extra map[string]any
|
||||||
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
|
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
|
||||||
// This should only be set when the caller has explicitly confirmed the risk.
|
// This should only be set when the caller has explicitly confirmed the risk.
|
||||||
SkipMixedChannelCheck bool
|
SkipMixedChannelCheck bool
|
||||||
@@ -220,23 +224,35 @@ type GenerateRedeemCodesInput struct {
|
|||||||
ValidityDays int // 订阅类型专用:有效天数
|
ValidityDays int // 订阅类型专用:有效天数
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProxyTestResult represents the result of testing a proxy
|
type ProxyBatchDeleteResult struct {
|
||||||
type ProxyTestResult struct {
|
DeletedIDs []int64 `json:"deleted_ids"`
|
||||||
Success bool `json:"success"`
|
Skipped []ProxyBatchDeleteSkipped `json:"skipped"`
|
||||||
Message string `json:"message"`
|
|
||||||
LatencyMs int64 `json:"latency_ms,omitempty"`
|
|
||||||
IPAddress string `json:"ip_address,omitempty"`
|
|
||||||
City string `json:"city,omitempty"`
|
|
||||||
Region string `json:"region,omitempty"`
|
|
||||||
Country string `json:"country,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProxyExitInfo represents proxy exit information from ipinfo.io
|
type ProxyBatchDeleteSkipped struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyTestResult represents the result of testing a proxy
|
||||||
|
type ProxyTestResult struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
LatencyMs int64 `json:"latency_ms,omitempty"`
|
||||||
|
IPAddress string `json:"ip_address,omitempty"`
|
||||||
|
City string `json:"city,omitempty"`
|
||||||
|
Region string `json:"region,omitempty"`
|
||||||
|
Country string `json:"country,omitempty"`
|
||||||
|
CountryCode string `json:"country_code,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyExitInfo represents proxy exit information from ip-api.com
|
||||||
type ProxyExitInfo struct {
|
type ProxyExitInfo struct {
|
||||||
IP string
|
IP string
|
||||||
City string
|
City string
|
||||||
Region string
|
Region string
|
||||||
Country string
|
Country string
|
||||||
|
CountryCode string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
|
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
|
||||||
@@ -254,6 +270,7 @@ type adminServiceImpl struct {
|
|||||||
redeemCodeRepo RedeemCodeRepository
|
redeemCodeRepo RedeemCodeRepository
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
proxyProber ProxyExitInfoProber
|
proxyProber ProxyExitInfoProber
|
||||||
|
proxyLatencyCache ProxyLatencyCache
|
||||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -267,6 +284,7 @@ func NewAdminService(
|
|||||||
redeemCodeRepo RedeemCodeRepository,
|
redeemCodeRepo RedeemCodeRepository,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
proxyProber ProxyExitInfoProber,
|
proxyProber ProxyExitInfoProber,
|
||||||
|
proxyLatencyCache ProxyLatencyCache,
|
||||||
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||||
) AdminService {
|
) AdminService {
|
||||||
return &adminServiceImpl{
|
return &adminServiceImpl{
|
||||||
@@ -278,6 +296,7 @@ func NewAdminService(
|
|||||||
redeemCodeRepo: redeemCodeRepo,
|
redeemCodeRepo: redeemCodeRepo,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
proxyProber: proxyProber,
|
proxyProber: proxyProber,
|
||||||
|
proxyLatencyCache: proxyLatencyCache,
|
||||||
authCacheInvalidator: authCacheInvalidator,
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -817,6 +836,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
|||||||
} else {
|
} else {
|
||||||
account.AutoPauseOnExpired = true
|
account.AutoPauseOnExpired = true
|
||||||
}
|
}
|
||||||
|
if input.RateMultiplier != nil {
|
||||||
|
if *input.RateMultiplier < 0 {
|
||||||
|
return nil, errors.New("rate_multiplier must be >= 0")
|
||||||
|
}
|
||||||
|
account.RateMultiplier = input.RateMultiplier
|
||||||
|
}
|
||||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -869,6 +894,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
if input.Priority != nil {
|
if input.Priority != nil {
|
||||||
account.Priority = *input.Priority
|
account.Priority = *input.Priority
|
||||||
}
|
}
|
||||||
|
if input.RateMultiplier != nil {
|
||||||
|
if *input.RateMultiplier < 0 {
|
||||||
|
return nil, errors.New("rate_multiplier must be >= 0")
|
||||||
|
}
|
||||||
|
account.RateMultiplier = input.RateMultiplier
|
||||||
|
}
|
||||||
if input.Status != "" {
|
if input.Status != "" {
|
||||||
account.Status = input.Status
|
account.Status = input.Status
|
||||||
}
|
}
|
||||||
@@ -942,6 +973,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.RateMultiplier != nil {
|
||||||
|
if *input.RateMultiplier < 0 {
|
||||||
|
return nil, errors.New("rate_multiplier must be >= 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Prepare bulk updates for columns and JSONB fields.
|
// Prepare bulk updates for columns and JSONB fields.
|
||||||
repoUpdates := AccountBulkUpdate{
|
repoUpdates := AccountBulkUpdate{
|
||||||
Credentials: input.Credentials,
|
Credentials: input.Credentials,
|
||||||
@@ -959,6 +996,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
if input.Priority != nil {
|
if input.Priority != nil {
|
||||||
repoUpdates.Priority = input.Priority
|
repoUpdates.Priority = input.Priority
|
||||||
}
|
}
|
||||||
|
if input.RateMultiplier != nil {
|
||||||
|
repoUpdates.RateMultiplier = input.RateMultiplier
|
||||||
|
}
|
||||||
if input.Status != "" {
|
if input.Status != "" {
|
||||||
repoUpdates.Status = &input.Status
|
repoUpdates.Status = &input.Status
|
||||||
}
|
}
|
||||||
@@ -1069,6 +1109,7 @@ func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
s.attachProxyLatency(ctx, proxies)
|
||||||
return proxies, result.Total, nil
|
return proxies, result.Total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1077,7 +1118,12 @@ func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
|
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
|
||||||
return s.proxyRepo.ListActiveWithAccountCount(ctx)
|
proxies, err := s.proxyRepo.ListActiveWithAccountCount(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.attachProxyLatency(ctx, proxies)
|
||||||
|
return proxies, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
|
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
|
||||||
@@ -1097,6 +1143,8 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
|
|||||||
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
|
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// Probe latency asynchronously so creation isn't blocked by network timeout.
|
||||||
|
go s.probeProxyLatency(context.Background(), proxy)
|
||||||
return proxy, nil
|
return proxy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1135,12 +1183,53 @@ func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *Upd
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
|
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
|
||||||
|
count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count > 0 {
|
||||||
|
return ErrProxyInUse
|
||||||
|
}
|
||||||
return s.proxyRepo.Delete(ctx, id)
|
return s.proxyRepo.Delete(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
|
func (s *adminServiceImpl) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) {
|
||||||
// Return mock data for now - would need a dedicated repository method
|
result := &ProxyBatchDeleteResult{}
|
||||||
return []Account{}, 0, nil
|
if len(ids) == 0 {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range ids {
|
||||||
|
count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
|
||||||
|
ID: id,
|
||||||
|
Reason: err.Error(),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if count > 0 {
|
||||||
|
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
|
||||||
|
ID: id,
|
||||||
|
Reason: ErrProxyInUse.Error(),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := s.proxyRepo.Delete(ctx, id); err != nil {
|
||||||
|
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
|
||||||
|
ID: id,
|
||||||
|
Reason: err.Error(),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result.DeletedIDs = append(result.DeletedIDs, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
|
||||||
|
return s.proxyRepo.ListAccountSummariesByProxyID(ctx, proxyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||||
@@ -1240,23 +1329,69 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
|
|||||||
proxyURL := proxy.URL()
|
proxyURL := proxy.URL()
|
||||||
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
|
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
|
||||||
|
Success: false,
|
||||||
|
Message: err.Error(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
})
|
||||||
return &ProxyTestResult{
|
return &ProxyTestResult{
|
||||||
Success: false,
|
Success: false,
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
latency := latencyMs
|
||||||
|
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
|
||||||
|
Success: true,
|
||||||
|
LatencyMs: &latency,
|
||||||
|
Message: "Proxy is accessible",
|
||||||
|
IPAddress: exitInfo.IP,
|
||||||
|
Country: exitInfo.Country,
|
||||||
|
CountryCode: exitInfo.CountryCode,
|
||||||
|
Region: exitInfo.Region,
|
||||||
|
City: exitInfo.City,
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
})
|
||||||
return &ProxyTestResult{
|
return &ProxyTestResult{
|
||||||
Success: true,
|
Success: true,
|
||||||
Message: "Proxy is accessible",
|
Message: "Proxy is accessible",
|
||||||
LatencyMs: latencyMs,
|
LatencyMs: latencyMs,
|
||||||
IPAddress: exitInfo.IP,
|
IPAddress: exitInfo.IP,
|
||||||
City: exitInfo.City,
|
City: exitInfo.City,
|
||||||
Region: exitInfo.Region,
|
Region: exitInfo.Region,
|
||||||
Country: exitInfo.Country,
|
Country: exitInfo.Country,
|
||||||
|
CountryCode: exitInfo.CountryCode,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
|
||||||
|
if s.proxyProber == nil || proxy == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL())
|
||||||
|
if err != nil {
|
||||||
|
s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
|
||||||
|
Success: false,
|
||||||
|
Message: err.Error(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
latency := latencyMs
|
||||||
|
s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
|
||||||
|
Success: true,
|
||||||
|
LatencyMs: &latency,
|
||||||
|
Message: "Proxy is accessible",
|
||||||
|
IPAddress: exitInfo.IP,
|
||||||
|
Country: exitInfo.Country,
|
||||||
|
CountryCode: exitInfo.CountryCode,
|
||||||
|
Region: exitInfo.Region,
|
||||||
|
City: exitInfo.City,
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic)
|
// checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic)
|
||||||
// 如果存在混合,返回错误提示用户确认
|
// 如果存在混合,返回错误提示用户确认
|
||||||
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||||
@@ -1306,6 +1441,51 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
|
||||||
|
if s.proxyLatencyCache == nil || len(proxies) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := make([]int64, 0, len(proxies))
|
||||||
|
for i := range proxies {
|
||||||
|
ids = append(ids, proxies[i].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Warning: load proxy latency cache failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range proxies {
|
||||||
|
info := latencies[proxies[i].ID]
|
||||||
|
if info == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if info.Success {
|
||||||
|
proxies[i].LatencyStatus = "success"
|
||||||
|
proxies[i].LatencyMs = info.LatencyMs
|
||||||
|
} else {
|
||||||
|
proxies[i].LatencyStatus = "failed"
|
||||||
|
}
|
||||||
|
proxies[i].LatencyMessage = info.Message
|
||||||
|
proxies[i].IPAddress = info.IPAddress
|
||||||
|
proxies[i].Country = info.Country
|
||||||
|
proxies[i].CountryCode = info.CountryCode
|
||||||
|
proxies[i].Region = info.Region
|
||||||
|
proxies[i].City = info.City
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) {
|
||||||
|
if s.proxyLatencyCache == nil || info == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
|
||||||
|
log.Printf("Warning: store proxy latency cache failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
|
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
|
||||||
func getAccountPlatform(accountPlatform string) string {
|
func getAccountPlatform(accountPlatform string) string {
|
||||||
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
|
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ import (
|
|||||||
|
|
||||||
type accountRepoStubForBulkUpdate struct {
|
type accountRepoStubForBulkUpdate struct {
|
||||||
accountRepoStub
|
accountRepoStub
|
||||||
bulkUpdateErr error
|
bulkUpdateErr error
|
||||||
bulkUpdateIDs []int64
|
bulkUpdateIDs []int64
|
||||||
bindGroupErrByID map[int64]error
|
bindGroupErrByID map[int64]error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
||||||
|
|||||||
@@ -153,8 +153,10 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
|
|||||||
}
|
}
|
||||||
|
|
||||||
type proxyRepoStub struct {
|
type proxyRepoStub struct {
|
||||||
deleteErr error
|
deleteErr error
|
||||||
deletedIDs []int64
|
countErr error
|
||||||
|
accountCount int64
|
||||||
|
deletedIDs []int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *proxyRepoStub) Create(ctx context.Context, proxy *Proxy) error {
|
func (s *proxyRepoStub) Create(ctx context.Context, proxy *Proxy) error {
|
||||||
@@ -199,7 +201,14 @@ func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, p
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||||
panic("unexpected CountAccountsByProxyID call")
|
if s.countErr != nil {
|
||||||
|
return 0, s.countErr
|
||||||
|
}
|
||||||
|
return s.accountCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
|
||||||
|
panic("unexpected ListAccountSummariesByProxyID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
type redeemRepoStub struct {
|
type redeemRepoStub struct {
|
||||||
@@ -409,6 +418,15 @@ func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
|
|||||||
require.Equal(t, []int64{404}, repo.deletedIDs)
|
require.Equal(t, []int64{404}, repo.deletedIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdminService_DeleteProxy_InUse(t *testing.T) {
|
||||||
|
repo := &proxyRepoStub{accountCount: 2}
|
||||||
|
svc := &adminServiceImpl{proxyRepo: repo}
|
||||||
|
|
||||||
|
err := svc.DeleteProxy(context.Background(), 77)
|
||||||
|
require.ErrorIs(t, err, ErrProxyInUse)
|
||||||
|
require.Empty(t, repo.deletedIDs)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAdminService_DeleteProxy_Error(t *testing.T) {
|
func TestAdminService_DeleteProxy_Error(t *testing.T) {
|
||||||
deleteErr := errors.New("delete failed")
|
deleteErr := errors.New("delete failed")
|
||||||
repo := &proxyRepoStub{deleteErr: deleteErr}
|
repo := &proxyRepoStub{deleteErr: deleteErr}
|
||||||
|
|||||||
@@ -523,6 +523,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sanitize thinking blocks (clean cache_control and flatten history thinking)
|
||||||
|
sanitizeThinkingBlocks(&claudeReq)
|
||||||
|
|
||||||
// 获取转换选项
|
// 获取转换选项
|
||||||
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
|
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
|
||||||
transformOpts := s.getClaudeTransformOptions(ctx)
|
transformOpts := s.getClaudeTransformOptions(ctx)
|
||||||
@@ -534,6 +537,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
return nil, fmt.Errorf("transform request: %w", err)
|
return nil, fmt.Errorf("transform request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Safety net: ensure no cache_control leaked into Gemini request
|
||||||
|
geminiBody = cleanCacheControlFromGeminiJSON(geminiBody)
|
||||||
|
|
||||||
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
|
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
|
||||||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
|
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
|
||||||
action := "streamGenerateContent"
|
action := "streamGenerateContent"
|
||||||
@@ -558,6 +564,10 @@ urlFallbackLoop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
|
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
|
||||||
|
// Capture upstream request body for ops retry of this attempt.
|
||||||
|
if c != nil {
|
||||||
|
c.Set(OpsUpstreamRequestBodyKey, string(geminiBody))
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -568,6 +578,7 @@ urlFallbackLoop:
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: 0,
|
UpstreamStatusCode: 0,
|
||||||
Kind: "request_error",
|
Kind: "request_error",
|
||||||
Message: safeErr,
|
Message: safeErr,
|
||||||
@@ -609,6 +620,7 @@ urlFallbackLoop:
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: "retry",
|
Kind: "retry",
|
||||||
@@ -639,6 +651,7 @@ urlFallbackLoop:
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: "retry",
|
Kind: "retry",
|
||||||
@@ -691,6 +704,7 @@ urlFallbackLoop:
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: "signature_error",
|
Kind: "signature_error",
|
||||||
@@ -734,6 +748,7 @@ urlFallbackLoop:
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: 0,
|
UpstreamStatusCode: 0,
|
||||||
Kind: "signature_retry_request_error",
|
Kind: "signature_retry_request_error",
|
||||||
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||||||
@@ -764,6 +779,7 @@ urlFallbackLoop:
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: retryResp.StatusCode,
|
UpstreamStatusCode: retryResp.StatusCode,
|
||||||
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||||||
Kind: kind,
|
Kind: kind,
|
||||||
@@ -811,6 +827,7 @@ urlFallbackLoop:
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: "failover",
|
Kind: "failover",
|
||||||
@@ -903,6 +920,143 @@ func extractAntigravityErrorMessage(body []byte) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cleanCacheControlFromGeminiJSON removes cache_control from Gemini JSON (emergency fix)
|
||||||
|
// This should not be needed if transformation is correct, but serves as a safety net
|
||||||
|
func cleanCacheControlFromGeminiJSON(body []byte) []byte {
|
||||||
|
// Try a more robust approach: parse and clean
|
||||||
|
var data map[string]any
|
||||||
|
if err := json.Unmarshal(body, &data); err != nil {
|
||||||
|
log.Printf("[Antigravity] Failed to parse Gemini JSON for cache_control cleaning: %v", err)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned := removeCacheControlFromAny(data)
|
||||||
|
if !cleaned {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
if result, err := json.Marshal(data); err == nil {
|
||||||
|
log.Printf("[Antigravity] Successfully cleaned cache_control from Gemini JSON")
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeCacheControlFromAny recursively removes cache_control fields
|
||||||
|
func removeCacheControlFromAny(v any) bool {
|
||||||
|
cleaned := false
|
||||||
|
|
||||||
|
switch val := v.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
for k, child := range val {
|
||||||
|
if k == "cache_control" {
|
||||||
|
delete(val, k)
|
||||||
|
cleaned = true
|
||||||
|
} else if removeCacheControlFromAny(child) {
|
||||||
|
cleaned = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, item := range val {
|
||||||
|
if removeCacheControlFromAny(item) {
|
||||||
|
cleaned = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeThinkingBlocks cleans cache_control and flattens history thinking blocks
|
||||||
|
// Thinking blocks do NOT support cache_control field (Anthropic API/Vertex AI requirement)
|
||||||
|
// Additionally, history thinking blocks are flattened to text to avoid upstream validation errors
|
||||||
|
func sanitizeThinkingBlocks(req *antigravity.ClaudeRequest) {
|
||||||
|
if req == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[Antigravity] sanitizeThinkingBlocks: processing request with %d messages", len(req.Messages))
|
||||||
|
|
||||||
|
// Clean system blocks
|
||||||
|
if len(req.System) > 0 {
|
||||||
|
var systemBlocks []map[string]any
|
||||||
|
if err := json.Unmarshal(req.System, &systemBlocks); err == nil {
|
||||||
|
for i := range systemBlocks {
|
||||||
|
if blockType, _ := systemBlocks[i]["type"].(string); blockType == "thinking" || systemBlocks[i]["thinking"] != nil {
|
||||||
|
if removeCacheControlFromAny(systemBlocks[i]) {
|
||||||
|
log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in system[%d]", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Marshal back
|
||||||
|
if cleaned, err := json.Marshal(systemBlocks); err == nil {
|
||||||
|
req.System = cleaned
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean message content blocks and flatten history
|
||||||
|
lastMsgIdx := len(req.Messages) - 1
|
||||||
|
for msgIdx := range req.Messages {
|
||||||
|
raw := req.Messages[msgIdx].Content
|
||||||
|
if len(raw) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse as blocks array
|
||||||
|
var blocks []map[string]any
|
||||||
|
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned := false
|
||||||
|
for blockIdx := range blocks {
|
||||||
|
blockType, _ := blocks[blockIdx]["type"].(string)
|
||||||
|
|
||||||
|
// Check for thinking blocks (typed or untyped)
|
||||||
|
if blockType == "thinking" || blocks[blockIdx]["thinking"] != nil {
|
||||||
|
// 1. Clean cache_control
|
||||||
|
if removeCacheControlFromAny(blocks[blockIdx]) {
|
||||||
|
log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in messages[%d].content[%d]", msgIdx, blockIdx)
|
||||||
|
cleaned = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Flatten to text if it's a history message (not the last one)
|
||||||
|
if msgIdx < lastMsgIdx {
|
||||||
|
log.Printf("[Antigravity] Flattening history thinking block to text at messages[%d].content[%d]", msgIdx, blockIdx)
|
||||||
|
|
||||||
|
// Extract thinking content
|
||||||
|
var textContent string
|
||||||
|
if t, ok := blocks[blockIdx]["thinking"].(string); ok {
|
||||||
|
textContent = t
|
||||||
|
} else {
|
||||||
|
// Fallback for non-string content (marshal it)
|
||||||
|
if b, err := json.Marshal(blocks[blockIdx]["thinking"]); err == nil {
|
||||||
|
textContent = string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to text block
|
||||||
|
blocks[blockIdx]["type"] = "text"
|
||||||
|
blocks[blockIdx]["text"] = textContent
|
||||||
|
delete(blocks[blockIdx], "thinking")
|
||||||
|
delete(blocks[blockIdx], "signature")
|
||||||
|
delete(blocks[blockIdx], "cache_control") // Ensure it's gone
|
||||||
|
cleaned = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal back if modified
|
||||||
|
if cleaned {
|
||||||
|
if marshaled, err := json.Marshal(blocks); err == nil {
|
||||||
|
req.Messages[msgIdx].Content = marshaled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
|
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
|
||||||
// This preserves the thinking content while avoiding signature validation errors.
|
// This preserves the thinking content while avoiding signature validation errors.
|
||||||
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
|
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
|
||||||
@@ -1228,6 +1382,7 @@ urlFallbackLoop:
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: 0,
|
UpstreamStatusCode: 0,
|
||||||
Kind: "request_error",
|
Kind: "request_error",
|
||||||
Message: safeErr,
|
Message: safeErr,
|
||||||
@@ -1269,6 +1424,7 @@ urlFallbackLoop:
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: "retry",
|
Kind: "retry",
|
||||||
@@ -1299,6 +1455,7 @@ urlFallbackLoop:
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: "retry",
|
Kind: "retry",
|
||||||
@@ -1400,6 +1557,7 @@ urlFallbackLoop:
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: requestID,
|
UpstreamRequestID: requestID,
|
||||||
Kind: "failover",
|
Kind: "failover",
|
||||||
@@ -1416,6 +1574,7 @@ urlFallbackLoop:
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: requestID,
|
UpstreamRequestID: requestID,
|
||||||
Kind: "http_error",
|
Kind: "http_error",
|
||||||
@@ -1717,6 +1876,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Printf("Stream data interval timeout (antigravity)")
|
log.Printf("Stream data interval timeout (antigravity)")
|
||||||
|
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
}
|
}
|
||||||
@@ -1895,6 +2055,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: upstreamStatus,
|
UpstreamStatusCode: upstreamStatus,
|
||||||
UpstreamRequestID: upstreamRequestID,
|
UpstreamRequestID: upstreamRequestID,
|
||||||
Kind: "http_error",
|
Kind: "http_error",
|
||||||
@@ -2271,6 +2432,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Printf("Stream data interval timeout (antigravity)")
|
log.Printf("Stream data interval timeout (antigravity)")
|
||||||
|
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return "", errors.New("not an antigravity oauth account")
|
return "", errors.New("not an antigravity oauth account")
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheKey := antigravityTokenCacheKey(account)
|
cacheKey := AntigravityTokenCacheKey(account)
|
||||||
|
|
||||||
// 1. 先尝试缓存
|
// 1. 先尝试缓存
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func antigravityTokenCacheKey(account *Account) string {
|
func AntigravityTokenCacheKey(account *Account) string {
|
||||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
if projectID != "" {
|
if projectID != "" {
|
||||||
return "ag:" + projectID
|
return "ag:" + projectID
|
||||||
|
|||||||
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
|
|||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
|
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
|
||||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID)
|
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get usage trend with filters: %w", err)
|
return nil, fmt.Errorf("get usage trend with filters: %w", err)
|
||||||
}
|
}
|
||||||
return trend, nil
|
return trend, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]usagestats.ModelStat, error) {
|
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
|
||||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, 0)
|
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get model stats with filters: %w", err)
|
return nil, fmt.Errorf("get model stats with filters: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,6 +146,13 @@ const (
|
|||||||
|
|
||||||
// SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation).
|
// SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation).
|
||||||
SettingKeyOpsAdvancedSettings = "ops_advanced_settings"
|
SettingKeyOpsAdvancedSettings = "ops_advanced_settings"
|
||||||
|
|
||||||
|
// =========================
|
||||||
|
// Stream Timeout Handling
|
||||||
|
// =========================
|
||||||
|
|
||||||
|
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
|
||||||
|
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||||
|
|||||||
@@ -1211,6 +1211,72 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
require.Nil(t, result)
|
require.Nil(t, result)
|
||||||
require.Contains(t, err.Error(), "no available accounts")
|
require.Contains(t, err.Error(), "no available accounts")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
resetAt := now.Add(10 * time.Minute)
|
||||||
|
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, RateLimitResetAt: &resetAt},
|
||||||
|
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.Account)
|
||||||
|
require.Equal(t, int64(2), result.Account.ID, "应跳过限流账号,选择可用账号")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("过滤不可调度账号-过载账号被跳过", func(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
overloadUntil := now.Add(10 * time.Minute)
|
||||||
|
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, OverloadUntil: &overloadUntil},
|
||||||
|
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.Account)
|
||||||
|
require.Equal(t, int64(2), result.Account.ID, "应跳过过载账号,选择可用账号")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
||||||
|
|||||||
@@ -151,6 +151,7 @@ type GatewayService struct {
|
|||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
cache GatewayCache
|
cache GatewayCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
schedulerSnapshot *SchedulerSnapshotService
|
||||||
billingService *BillingService
|
billingService *BillingService
|
||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
@@ -169,6 +170,7 @@ func NewGatewayService(
|
|||||||
userSubRepo UserSubscriptionRepository,
|
userSubRepo UserSubscriptionRepository,
|
||||||
cache GatewayCache,
|
cache GatewayCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
|
schedulerSnapshot *SchedulerSnapshotService,
|
||||||
concurrencyService *ConcurrencyService,
|
concurrencyService *ConcurrencyService,
|
||||||
billingService *BillingService,
|
billingService *BillingService,
|
||||||
rateLimitService *RateLimitService,
|
rateLimitService *RateLimitService,
|
||||||
@@ -185,6 +187,7 @@ func NewGatewayService(
|
|||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
schedulerSnapshot: schedulerSnapshot,
|
||||||
concurrencyService: concurrencyService,
|
concurrencyService: concurrencyService,
|
||||||
billingService: billingService,
|
billingService: billingService,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
@@ -508,6 +511,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if isExcluded(acc.ID) {
|
if isExcluded(acc.ID) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
|
||||||
|
// re-check schedulability here so recently rate-limited/overloaded accounts
|
||||||
|
// are not selected again before the bucket is rebuilt.
|
||||||
|
if !acc.IsSchedulable() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
|
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -745,6 +754,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||||
|
if s.schedulerSnapshot != nil {
|
||||||
|
return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||||
|
}
|
||||||
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
||||||
if useMixed {
|
if useMixed {
|
||||||
platforms := []string{platform, PlatformAntigravity}
|
platforms := []string{platform, PlatformAntigravity}
|
||||||
@@ -821,6 +833,13 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
|
|||||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||||
|
if s.schedulerSnapshot != nil {
|
||||||
|
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
return s.accountRepo.GetByID(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||||
sort.SliceStable(accounts, func(i, j int) bool {
|
sort.SliceStable(accounts, func(i, j int) bool {
|
||||||
a, b := accounts[i], accounts[j]
|
a, b := accounts[i], accounts[j]
|
||||||
@@ -851,7 +870,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||||
@@ -864,16 +883,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. 获取可调度账号列表(单平台)
|
// 2. 获取可调度账号列表(单平台)
|
||||||
var accounts []Account
|
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||||
var err error
|
if hasForcePlatform && forcePlatform == "" {
|
||||||
if s.cfg.RunMode == config.RunModeSimple {
|
hasForcePlatform = false
|
||||||
// 简易模式:忽略 groupID,查询所有可用账号
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
|
||||||
} else if groupID != nil {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
|
||||||
}
|
}
|
||||||
|
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -885,6 +899,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
|
||||||
|
// avoid selecting accounts that were recently rate-limited/overloaded.
|
||||||
|
if !acc.IsSchedulable() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !acc.IsSchedulableForModel(requestedModel) {
|
if !acc.IsSchedulableForModel(requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -935,7 +954,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
|
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
|
||||||
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
||||||
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
||||||
platforms := []string{nativePlatform, PlatformAntigravity}
|
|
||||||
preferOAuth := nativePlatform == PlatformGemini
|
preferOAuth := nativePlatform == PlatformGemini
|
||||||
|
|
||||||
// 1. 查询粘性会话
|
// 1. 查询粘性会话
|
||||||
@@ -943,7 +961,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||||
@@ -958,13 +976,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. 获取可调度账号列表
|
// 2. 获取可调度账号列表
|
||||||
var accounts []Account
|
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
|
||||||
var err error
|
|
||||||
if groupID != nil {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -976,6 +988,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
|
||||||
|
// avoid selecting accounts that were recently rate-limited/overloaded.
|
||||||
|
if !acc.IsSchedulable() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||||
continue
|
continue
|
||||||
@@ -1226,6 +1243,9 @@ func enforceCacheControlLimit(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 清理 thinking 块中的非法 cache_control(thinking 块不支持该字段)
|
||||||
|
removeCacheControlFromThinkingBlocks(data)
|
||||||
|
|
||||||
// 计算当前 cache_control 块数量
|
// 计算当前 cache_control 块数量
|
||||||
count := countCacheControlBlocks(data)
|
count := countCacheControlBlocks(data)
|
||||||
if count <= maxCacheControlBlocks {
|
if count <= maxCacheControlBlocks {
|
||||||
@@ -1253,6 +1273,7 @@ func enforceCacheControlLimit(body []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
|
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
|
||||||
|
// 注意:thinking 块不支持 cache_control,统计时跳过
|
||||||
func countCacheControlBlocks(data map[string]any) int {
|
func countCacheControlBlocks(data map[string]any) int {
|
||||||
count := 0
|
count := 0
|
||||||
|
|
||||||
@@ -1260,6 +1281,10 @@ func countCacheControlBlocks(data map[string]any) int {
|
|||||||
if system, ok := data["system"].([]any); ok {
|
if system, ok := data["system"].([]any); ok {
|
||||||
for _, item := range system {
|
for _, item := range system {
|
||||||
if m, ok := item.(map[string]any); ok {
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
// thinking 块不支持 cache_control,跳过
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, has := m["cache_control"]; has {
|
if _, has := m["cache_control"]; has {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
@@ -1274,6 +1299,10 @@ func countCacheControlBlocks(data map[string]any) int {
|
|||||||
if content, ok := msgMap["content"].([]any); ok {
|
if content, ok := msgMap["content"].([]any); ok {
|
||||||
for _, item := range content {
|
for _, item := range content {
|
||||||
if m, ok := item.(map[string]any); ok {
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
// thinking 块不支持 cache_control,跳过
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, has := m["cache_control"]; has {
|
if _, has := m["cache_control"]; has {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
@@ -1289,6 +1318,7 @@ func countCacheControlBlocks(data map[string]any) int {
|
|||||||
|
|
||||||
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始)
|
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始)
|
||||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
// 返回 true 表示成功移除,false 表示没有可移除的
|
||||||
|
// 注意:跳过 thinking 块(它不支持 cache_control)
|
||||||
func removeCacheControlFromMessages(data map[string]any) bool {
|
func removeCacheControlFromMessages(data map[string]any) bool {
|
||||||
messages, ok := data["messages"].([]any)
|
messages, ok := data["messages"].([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -1306,6 +1336,10 @@ func removeCacheControlFromMessages(data map[string]any) bool {
|
|||||||
}
|
}
|
||||||
for _, item := range content {
|
for _, item := range content {
|
||||||
if m, ok := item.(map[string]any); ok {
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
// thinking 块不支持 cache_control,跳过
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, has := m["cache_control"]; has {
|
if _, has := m["cache_control"]; has {
|
||||||
delete(m, "cache_control")
|
delete(m, "cache_control")
|
||||||
return true
|
return true
|
||||||
@@ -1318,6 +1352,7 @@ func removeCacheControlFromMessages(data map[string]any) bool {
|
|||||||
|
|
||||||
// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt)
|
// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt)
|
||||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
// 返回 true 表示成功移除,false 表示没有可移除的
|
||||||
|
// 注意:跳过 thinking 块(它不支持 cache_control)
|
||||||
func removeCacheControlFromSystem(data map[string]any) bool {
|
func removeCacheControlFromSystem(data map[string]any) bool {
|
||||||
system, ok := data["system"].([]any)
|
system, ok := data["system"].([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -1327,6 +1362,10 @@ func removeCacheControlFromSystem(data map[string]any) bool {
|
|||||||
// 从尾部开始移除,保护开头注入的 Claude Code prompt
|
// 从尾部开始移除,保护开头注入的 Claude Code prompt
|
||||||
for i := len(system) - 1; i >= 0; i-- {
|
for i := len(system) - 1; i >= 0; i-- {
|
||||||
if m, ok := system[i].(map[string]any); ok {
|
if m, ok := system[i].(map[string]any); ok {
|
||||||
|
// thinking 块不支持 cache_control,跳过
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, has := m["cache_control"]; has {
|
if _, has := m["cache_control"]; has {
|
||||||
delete(m, "cache_control")
|
delete(m, "cache_control")
|
||||||
return true
|
return true
|
||||||
@@ -1336,6 +1375,44 @@ func removeCacheControlFromSystem(data map[string]any) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeCacheControlFromThinkingBlocks 强制清理所有 thinking 块中的非法 cache_control
|
||||||
|
// thinking 块不支持 cache_control 字段,这个函数确保所有 thinking 块都不含该字段
|
||||||
|
func removeCacheControlFromThinkingBlocks(data map[string]any) {
|
||||||
|
// 清理 system 中的 thinking 块
|
||||||
|
if system, ok := data["system"].([]any); ok {
|
||||||
|
for _, item := range system {
|
||||||
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
if _, has := m["cache_control"]; has {
|
||||||
|
delete(m, "cache_control")
|
||||||
|
log.Printf("[Warning] Removed illegal cache_control from thinking block in system")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理 messages 中的 thinking 块
|
||||||
|
if messages, ok := data["messages"].([]any); ok {
|
||||||
|
for msgIdx, msg := range messages {
|
||||||
|
if msgMap, ok := msg.(map[string]any); ok {
|
||||||
|
if content, ok := msgMap["content"].([]any); ok {
|
||||||
|
for contentIdx, item := range content {
|
||||||
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
if _, has := m["cache_control"]; has {
|
||||||
|
delete(m, "cache_control")
|
||||||
|
log.Printf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Forward 转发请求到Claude API
|
// Forward 转发请求到Claude API
|
||||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
@@ -1389,6 +1466,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
|
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
|
||||||
|
// Capture upstream request body for ops retry of this attempt.
|
||||||
|
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1405,6 +1485,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: 0,
|
UpstreamStatusCode: 0,
|
||||||
Kind: "request_error",
|
Kind: "request_error",
|
||||||
Message: safeErr,
|
Message: safeErr,
|
||||||
@@ -1429,6 +1510,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: "signature_error",
|
Kind: "signature_error",
|
||||||
@@ -1480,6 +1562,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: retryResp.StatusCode,
|
UpstreamStatusCode: retryResp.StatusCode,
|
||||||
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||||||
Kind: "signature_retry_thinking",
|
Kind: "signature_retry_thinking",
|
||||||
@@ -1508,6 +1591,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: 0,
|
UpstreamStatusCode: 0,
|
||||||
Kind: "signature_retry_tools_request_error",
|
Kind: "signature_retry_tools_request_error",
|
||||||
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
|
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
|
||||||
@@ -1566,6 +1650,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: "retry",
|
Kind: "retry",
|
||||||
@@ -1614,6 +1699,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: "retry_exhausted_failover",
|
Kind: "retry_exhausted_failover",
|
||||||
@@ -1680,6 +1766,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: "failover_on_400",
|
Kind: "failover_on_400",
|
||||||
@@ -2340,6 +2427,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
}
|
}
|
||||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||||
|
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||||
|
}
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
}
|
}
|
||||||
@@ -2553,30 +2644,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
if result.ImageSize != "" {
|
if result.ImageSize != "" {
|
||||||
imageSize = &result.ImageSize
|
imageSize = &result.ImageSize
|
||||||
}
|
}
|
||||||
|
accountRateMultiplier := account.BillingRateMultiplier()
|
||||||
usageLog := &UsageLog{
|
usageLog := &UsageLog{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
APIKeyID: apiKey.ID,
|
APIKeyID: apiKey.ID,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
RequestID: result.RequestID,
|
RequestID: result.RequestID,
|
||||||
Model: result.Model,
|
Model: result.Model,
|
||||||
InputTokens: result.Usage.InputTokens,
|
InputTokens: result.Usage.InputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
InputCost: cost.InputCost,
|
InputCost: cost.InputCost,
|
||||||
OutputCost: cost.OutputCost,
|
OutputCost: cost.OutputCost,
|
||||||
CacheCreationCost: cost.CacheCreationCost,
|
CacheCreationCost: cost.CacheCreationCost,
|
||||||
CacheReadCost: cost.CacheReadCost,
|
CacheReadCost: cost.CacheReadCost,
|
||||||
TotalCost: cost.TotalCost,
|
TotalCost: cost.TotalCost,
|
||||||
ActualCost: cost.ActualCost,
|
ActualCost: cost.ActualCost,
|
||||||
RateMultiplier: multiplier,
|
RateMultiplier: multiplier,
|
||||||
BillingType: billingType,
|
AccountRateMultiplier: &accountRateMultiplier,
|
||||||
Stream: result.Stream,
|
BillingType: billingType,
|
||||||
DurationMs: &durationMs,
|
Stream: result.Stream,
|
||||||
FirstTokenMs: result.FirstTokenMs,
|
DurationMs: &durationMs,
|
||||||
ImageCount: result.ImageCount,
|
FirstTokenMs: result.FirstTokenMs,
|
||||||
ImageSize: imageSize,
|
ImageCount: result.ImageCount,
|
||||||
CreatedAt: time.Now(),
|
ImageSize: imageSize,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加 UserAgent
|
// 添加 UserAgent
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ type GeminiMessagesCompatService struct {
|
|||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
cache GatewayCache
|
cache GatewayCache
|
||||||
|
schedulerSnapshot *SchedulerSnapshotService
|
||||||
tokenProvider *GeminiTokenProvider
|
tokenProvider *GeminiTokenProvider
|
||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
httpUpstream HTTPUpstream
|
httpUpstream HTTPUpstream
|
||||||
@@ -51,6 +52,7 @@ func NewGeminiMessagesCompatService(
|
|||||||
accountRepo AccountRepository,
|
accountRepo AccountRepository,
|
||||||
groupRepo GroupRepository,
|
groupRepo GroupRepository,
|
||||||
cache GatewayCache,
|
cache GatewayCache,
|
||||||
|
schedulerSnapshot *SchedulerSnapshotService,
|
||||||
tokenProvider *GeminiTokenProvider,
|
tokenProvider *GeminiTokenProvider,
|
||||||
rateLimitService *RateLimitService,
|
rateLimitService *RateLimitService,
|
||||||
httpUpstream HTTPUpstream,
|
httpUpstream HTTPUpstream,
|
||||||
@@ -61,6 +63,7 @@ func NewGeminiMessagesCompatService(
|
|||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
|
schedulerSnapshot: schedulerSnapshot,
|
||||||
tokenProvider: tokenProvider,
|
tokenProvider: tokenProvider,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
@@ -105,12 +108,6 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||||
// 注意:强制平台模式不走混合调度
|
// 注意:强制平台模式不走混合调度
|
||||||
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||||
var queryPlatforms []string
|
|
||||||
if useMixedScheduling {
|
|
||||||
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
|
|
||||||
} else {
|
|
||||||
queryPlatforms = []string{platform}
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheKey := "gemini:" + sessionHash
|
cacheKey := "gemini:" + sessionHash
|
||||||
|
|
||||||
@@ -118,7 +115,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||||
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
valid := false
|
valid := false
|
||||||
@@ -149,22 +146,16 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||||
var accounts []Account
|
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
|
||||||
var err error
|
if err != nil {
|
||||||
if groupID != nil {
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
}
|
||||||
|
// 强制平台模式下,分组中找不到账户时回退查询全部
|
||||||
|
if len(accounts) == 0 && groupID != nil && hasForcePlatform {
|
||||||
|
accounts, err = s.listSchedulableAccountsOnce(ctx, nil, platform, hasForcePlatform)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
}
|
}
|
||||||
// 强制平台模式下,分组中找不到账户时回退查询全部
|
|
||||||
if len(accounts) == 0 && hasForcePlatform {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var selected *Account
|
var selected *Account
|
||||||
@@ -245,6 +236,31 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
|
|||||||
return s.antigravityGatewayService
|
return s.antigravityGatewayService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||||
|
if s.schedulerSnapshot != nil {
|
||||||
|
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
return s.accountRepo.GetByID(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) {
|
||||||
|
if s.schedulerSnapshot != nil {
|
||||||
|
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||||
|
return accounts, err
|
||||||
|
}
|
||||||
|
|
||||||
|
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||||
|
queryPlatforms := []string{platform}
|
||||||
|
if useMixedScheduling {
|
||||||
|
queryPlatforms = []string{platform, PlatformAntigravity}
|
||||||
|
}
|
||||||
|
|
||||||
|
if groupID != nil {
|
||||||
|
return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
||||||
|
}
|
||||||
|
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||||
@@ -266,13 +282,7 @@ func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (strin
|
|||||||
|
|
||||||
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
|
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
|
||||||
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
|
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
|
||||||
var accounts []Account
|
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformAntigravity, false)
|
||||||
var err error
|
|
||||||
if groupID != nil {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -288,13 +298,7 @@ func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context
|
|||||||
// 3) OAuth accounts explicitly marked as ai_studio
|
// 3) OAuth accounts explicitly marked as ai_studio
|
||||||
// 4) Any remaining Gemini accounts (fallback)
|
// 4) Any remaining Gemini accounts (fallback)
|
||||||
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
|
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
|
||||||
var accounts []Account
|
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformGemini, true)
|
||||||
var err error
|
|
||||||
if groupID != nil {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -541,12 +545,19 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
}
|
}
|
||||||
requestIDHeader = idHeader
|
requestIDHeader = idHeader
|
||||||
|
|
||||||
|
// Capture upstream request body for ops retry of this attempt.
|
||||||
|
if c != nil {
|
||||||
|
// In this code path `body` is already the JSON sent to upstream.
|
||||||
|
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: 0,
|
UpstreamStatusCode: 0,
|
||||||
Kind: "request_error",
|
Kind: "request_error",
|
||||||
Message: safeErr,
|
Message: safeErr,
|
||||||
@@ -584,6 +595,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: upstreamReqID,
|
UpstreamRequestID: upstreamReqID,
|
||||||
Kind: "signature_error",
|
Kind: "signature_error",
|
||||||
@@ -658,6 +670,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: upstreamReqID,
|
UpstreamRequestID: upstreamReqID,
|
||||||
Kind: "retry",
|
Kind: "retry",
|
||||||
@@ -707,6 +720,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: upstreamReqID,
|
UpstreamRequestID: upstreamReqID,
|
||||||
Kind: "failover",
|
Kind: "failover",
|
||||||
@@ -733,6 +747,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: upstreamReqID,
|
UpstreamRequestID: upstreamReqID,
|
||||||
Kind: "failover",
|
Kind: "failover",
|
||||||
@@ -968,12 +983,19 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
}
|
}
|
||||||
requestIDHeader = idHeader
|
requestIDHeader = idHeader
|
||||||
|
|
||||||
|
// Capture upstream request body for ops retry of this attempt.
|
||||||
|
if c != nil {
|
||||||
|
// In this code path `body` is already the JSON sent to upstream.
|
||||||
|
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: 0,
|
UpstreamStatusCode: 0,
|
||||||
Kind: "request_error",
|
Kind: "request_error",
|
||||||
Message: safeErr,
|
Message: safeErr,
|
||||||
@@ -1032,6 +1054,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: upstreamReqID,
|
UpstreamRequestID: upstreamReqID,
|
||||||
Kind: "retry",
|
Kind: "retry",
|
||||||
@@ -1116,6 +1139,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: requestID,
|
UpstreamRequestID: requestID,
|
||||||
Kind: "failover",
|
Kind: "failover",
|
||||||
@@ -1139,6 +1163,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: requestID,
|
UpstreamRequestID: requestID,
|
||||||
Kind: "failover",
|
Kind: "failover",
|
||||||
@@ -1164,6 +1189,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: requestID,
|
UpstreamRequestID: requestID,
|
||||||
Kind: "http_error",
|
Kind: "http_error",
|
||||||
@@ -1296,6 +1322,7 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: upstreamStatus,
|
UpstreamStatusCode: upstreamStatus,
|
||||||
UpstreamRequestID: upstreamRequestID,
|
UpstreamRequestID: upstreamRequestID,
|
||||||
Kind: "http_error",
|
Kind: "http_error",
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
|
|||||||
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
|
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
|
||||||
GetAccessToken(ctx context.Context, cacheKey string) (string, error)
|
GetAccessToken(ctx context.Context, cacheKey string) (string, error)
|
||||||
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
|
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
|
||||||
|
DeleteAccessToken(ctx context.Context, cacheKey string) error
|
||||||
|
|
||||||
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
|
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
|
||||||
ReleaseRefreshLock(ctx context.Context, cacheKey string) error
|
ReleaseRefreshLock(ctx context.Context, cacheKey string) error
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
return "", errors.New("not a gemini oauth account")
|
return "", errors.New("not a gemini oauth account")
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheKey := geminiTokenCacheKey(account)
|
cacheKey := GeminiTokenCacheKey(account)
|
||||||
|
|
||||||
// 1) Try cache first.
|
// 1) Try cache first.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
@@ -151,7 +151,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func geminiTokenCacheKey(account *Account) string {
|
func GeminiTokenCacheKey(account *Account) string {
|
||||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
if projectID != "" {
|
if projectID != "" {
|
||||||
return projectID
|
return projectID
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
_ "embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -16,6 +17,9 @@ const (
|
|||||||
codexCacheTTL = 15 * time.Minute
|
codexCacheTTL = 15 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//go:embed prompts/codex_cli_instructions.md
|
||||||
|
var codexCLIInstructions string
|
||||||
|
|
||||||
var codexModelMap = map[string]string{
|
var codexModelMap = map[string]string{
|
||||||
"gpt-5.1-codex": "gpt-5.1-codex",
|
"gpt-5.1-codex": "gpt-5.1-codex",
|
||||||
"gpt-5.1-codex-low": "gpt-5.1-codex",
|
"gpt-5.1-codex-low": "gpt-5.1-codex",
|
||||||
@@ -70,6 +74,8 @@ type opencodeCacheMetadata struct {
|
|||||||
|
|
||||||
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
||||||
result := codexTransformResult{}
|
result := codexTransformResult{}
|
||||||
|
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||||
|
needsToolContinuation := NeedsToolContinuation(reqBody)
|
||||||
|
|
||||||
model := ""
|
model := ""
|
||||||
if v, ok := reqBody["model"].(string); ok {
|
if v, ok := reqBody["model"].(string); ok {
|
||||||
@@ -84,6 +90,8 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
|||||||
result.NormalizedModel = normalizedModel
|
result.NormalizedModel = normalizedModel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OAuth 走 ChatGPT internal API 时,store 必须为 false;显式 true 也会强制覆盖。
|
||||||
|
// 避免上游返回 "Store must be set to false"。
|
||||||
if v, ok := reqBody["store"].(bool); !ok || v {
|
if v, ok := reqBody["store"].(bool); !ok || v {
|
||||||
reqBody["store"] = false
|
reqBody["store"] = false
|
||||||
result.Modified = true
|
result.Modified = true
|
||||||
@@ -119,10 +127,18 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
|||||||
reqBody["instructions"] = instructions
|
reqBody["instructions"] = instructions
|
||||||
result.Modified = true
|
result.Modified = true
|
||||||
}
|
}
|
||||||
|
} else if existingInstructions == "" {
|
||||||
|
// 未获取到 opencode 指令时,回退使用 Codex CLI 指令。
|
||||||
|
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||||
|
if codexInstructions != "" {
|
||||||
|
reqBody["instructions"] = codexInstructions
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
|
||||||
if input, ok := reqBody["input"].([]any); ok {
|
if input, ok := reqBody["input"].([]any); ok {
|
||||||
input = filterCodexInput(input)
|
input = filterCodexInput(input, needsToolContinuation)
|
||||||
reqBody["input"] = input
|
reqBody["input"] = input
|
||||||
result.Modified = true
|
result.Modified = true
|
||||||
}
|
}
|
||||||
@@ -235,14 +251,75 @@ func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getOpenCodeCodexHeader() string {
|
func getOpenCodeCodexHeader() string {
|
||||||
return getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
|
// 优先从 opencode 仓库缓存获取指令。
|
||||||
|
opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
|
||||||
|
|
||||||
|
// 若 opencode 指令可用,直接返回。
|
||||||
|
if opencodeInstructions != "" {
|
||||||
|
return opencodeInstructions
|
||||||
|
}
|
||||||
|
|
||||||
|
// 否则回退使用本地 Codex CLI 指令。
|
||||||
|
return getCodexCLIInstructions()
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCodexCLIInstructions() string {
|
||||||
|
return codexCLIInstructions
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetOpenCodeInstructions() string {
|
func GetOpenCodeInstructions() string {
|
||||||
return getOpenCodeCodexHeader()
|
return getOpenCodeCodexHeader()
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterCodexInput(input []any) []any {
|
// GetCodexCLIInstructions 返回内置的 Codex CLI 指令内容。
|
||||||
|
func GetCodexCLIInstructions() string {
|
||||||
|
return getCodexCLIInstructions()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
|
||||||
|
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
|
||||||
|
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||||
|
if codexInstructions == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
existingInstructions, _ := reqBody["instructions"].(string)
|
||||||
|
if strings.TrimSpace(existingInstructions) != codexInstructions {
|
||||||
|
reqBody["instructions"] = codexInstructions
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInstructionError 判断错误信息是否与指令格式/系统提示相关。
|
||||||
|
func IsInstructionError(errorMessage string) bool {
|
||||||
|
if errorMessage == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
lowerMsg := strings.ToLower(errorMessage)
|
||||||
|
instructionKeywords := []string{
|
||||||
|
"instruction",
|
||||||
|
"instructions",
|
||||||
|
"system prompt",
|
||||||
|
"system message",
|
||||||
|
"invalid prompt",
|
||||||
|
"prompt format",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, keyword := range instructionKeywords {
|
||||||
|
if strings.Contains(lowerMsg, keyword) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterCodexInput 按需过滤 item_reference 与 id。
|
||||||
|
// preserveReferences 为 true 时保持引用与 id,以满足续链请求对上下文的依赖。
|
||||||
|
func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||||
filtered := make([]any, 0, len(input))
|
filtered := make([]any, 0, len(input))
|
||||||
for _, item := range input {
|
for _, item := range input {
|
||||||
m, ok := item.(map[string]any)
|
m, ok := item.(map[string]any)
|
||||||
@@ -250,15 +327,62 @@ func filterCodexInput(input []any) []any {
|
|||||||
filtered = append(filtered, item)
|
filtered = append(filtered, item)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if typ, ok := m["type"].(string); ok && typ == "item_reference" {
|
typ, _ := m["type"].(string)
|
||||||
|
if typ == "item_reference" {
|
||||||
|
if !preserveReferences {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newItem := make(map[string]any, len(m))
|
||||||
|
for key, value := range m {
|
||||||
|
newItem[key] = value
|
||||||
|
}
|
||||||
|
filtered = append(filtered, newItem)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
delete(m, "id")
|
|
||||||
filtered = append(filtered, m)
|
newItem := m
|
||||||
|
copied := false
|
||||||
|
// 仅在需要修改字段时创建副本,避免直接改写原始输入。
|
||||||
|
ensureCopy := func() {
|
||||||
|
if copied {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
newItem = make(map[string]any, len(m))
|
||||||
|
for key, value := range m {
|
||||||
|
newItem[key] = value
|
||||||
|
}
|
||||||
|
copied = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCodexToolCallItemType(typ) {
|
||||||
|
if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" {
|
||||||
|
if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
|
||||||
|
ensureCopy()
|
||||||
|
newItem["call_id"] = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !preserveReferences {
|
||||||
|
ensureCopy()
|
||||||
|
delete(newItem, "id")
|
||||||
|
if !isCodexToolCallItemType(typ) {
|
||||||
|
delete(newItem, "call_id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered = append(filtered, newItem)
|
||||||
}
|
}
|
||||||
return filtered
|
return filtered
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isCodexToolCallItemType(typ string) bool {
|
||||||
|
if typ == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.HasSuffix(typ, "_call") || strings.HasSuffix(typ, "_call_output")
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeCodexTools(reqBody map[string]any) bool {
|
func normalizeCodexTools(reqBody map[string]any) bool {
|
||||||
rawTools, ok := reqBody["tools"]
|
rawTools, ok := reqBody["tools"]
|
||||||
if !ok || rawTools == nil {
|
if !ok || rawTools == nil {
|
||||||
|
|||||||
167
backend/internal/service/openai_codex_transform_test.go
Normal file
167
backend/internal/service/openai_codex_transform_test.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||||
|
// 续链场景:保留 item_reference 与 id,但不再强制 store=true。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "item_reference", "id": "ref1", "text": "x"},
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"},
|
||||||
|
},
|
||||||
|
"tool_choice": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
// 未显式设置 store=true,默认为 false。
|
||||||
|
store, ok := reqBody["store"].(bool)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, store)
|
||||||
|
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 2)
|
||||||
|
|
||||||
|
// 校验 input[0] 为 map,避免断言失败导致测试中断。
|
||||||
|
first, ok := input[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "item_reference", first["type"])
|
||||||
|
require.Equal(t, "ref1", first["id"])
|
||||||
|
|
||||||
|
// 校验 input[1] 为 map,确保后续字段断言安全。
|
||||||
|
second, ok := input[1].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "o1", second["id"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||||
|
// 续链场景:显式 store=false 不再强制为 true,保持 false。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
"store": false,
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||||
|
},
|
||||||
|
"tool_choice": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
store, ok := reqBody["store"].(bool)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, store)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
|
||||||
|
// 显式 store=true 也会强制为 false。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
"store": true,
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||||
|
},
|
||||||
|
"tool_choice": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
store, ok := reqBody["store"].(bool)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, store)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
|
||||||
|
// 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "text", "id": "t1", "text": "hi"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
store, ok := reqBody["store"].(bool)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, store)
|
||||||
|
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 1)
|
||||||
|
// 校验 input[0] 为 map,避免类型不匹配触发 errcheck。
|
||||||
|
item, ok := input[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
_, hasID := item["id"]
|
||||||
|
require.False(t, hasID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
|
||||||
|
input := []any{
|
||||||
|
map[string]any{"type": "item_reference", "id": "ref1"},
|
||||||
|
map[string]any{"type": "text", "id": "t1", "text": "hi"},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := filterCodexInput(input, false)
|
||||||
|
require.Len(t, filtered, 1)
|
||||||
|
// 校验 filtered[0] 为 map,确保字段检查可靠。
|
||||||
|
item, ok := filtered[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "text", item["type"])
|
||||||
|
_, hasID := item["id"]
|
||||||
|
require.False(t, hasID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||||
|
// 空 input 应保持为空且不触发异常。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
"input": []any{},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupCodexCache(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// 使用临时 HOME 避免触发网络拉取 header。
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
t.Setenv("HOME", tempDir)
|
||||||
|
|
||||||
|
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
|
||||||
|
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
|
||||||
|
|
||||||
|
meta := map[string]any{
|
||||||
|
"etag": "",
|
||||||
|
"lastFetch": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
"lastChecked": time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(meta)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
|
||||||
|
}
|
||||||
@@ -42,6 +42,7 @@ var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
|||||||
var openaiAllowedHeaders = map[string]bool{
|
var openaiAllowedHeaders = map[string]bool{
|
||||||
"accept-language": true,
|
"accept-language": true,
|
||||||
"content-type": true,
|
"content-type": true,
|
||||||
|
"conversation_id": true,
|
||||||
"user-agent": true,
|
"user-agent": true,
|
||||||
"originator": true,
|
"originator": true,
|
||||||
"session_id": true,
|
"session_id": true,
|
||||||
@@ -85,6 +86,7 @@ type OpenAIGatewayService struct {
|
|||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
cache GatewayCache
|
cache GatewayCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
schedulerSnapshot *SchedulerSnapshotService
|
||||||
concurrencyService *ConcurrencyService
|
concurrencyService *ConcurrencyService
|
||||||
billingService *BillingService
|
billingService *BillingService
|
||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
@@ -101,6 +103,7 @@ func NewOpenAIGatewayService(
|
|||||||
userSubRepo UserSubscriptionRepository,
|
userSubRepo UserSubscriptionRepository,
|
||||||
cache GatewayCache,
|
cache GatewayCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
|
schedulerSnapshot *SchedulerSnapshotService,
|
||||||
concurrencyService *ConcurrencyService,
|
concurrencyService *ConcurrencyService,
|
||||||
billingService *BillingService,
|
billingService *BillingService,
|
||||||
rateLimitService *RateLimitService,
|
rateLimitService *RateLimitService,
|
||||||
@@ -115,6 +118,7 @@ func NewOpenAIGatewayService(
|
|||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
schedulerSnapshot: schedulerSnapshot,
|
||||||
concurrencyService: concurrencyService,
|
concurrencyService: concurrencyService,
|
||||||
billingService: billingService,
|
billingService: billingService,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
@@ -159,7 +163,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
|||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
// Refresh sticky session TTL
|
// Refresh sticky session TTL
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||||
@@ -170,16 +174,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. Get schedulable OpenAI accounts
|
// 2. Get schedulable OpenAI accounts
|
||||||
var accounts []Account
|
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
||||||
var err error
|
|
||||||
// 简易模式:忽略分组限制,查询所有可用账号
|
|
||||||
if s.cfg.RunMode == config.RunModeSimple {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
|
||||||
} else if groupID != nil {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -191,6 +186,11 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
|||||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
|
||||||
|
// avoid selecting accounts that were recently rate-limited/overloaded.
|
||||||
|
if !acc.IsSchedulable() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
// Check model support
|
// Check model support
|
||||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||||
continue
|
continue
|
||||||
@@ -301,7 +301,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
||||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
@@ -337,6 +337,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
if isExcluded(acc.ID) {
|
if isExcluded(acc.ID) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
|
||||||
|
// re-check schedulability here so recently rate-limited/overloaded accounts
|
||||||
|
// are not selected again before the bucket is rebuilt.
|
||||||
|
if !acc.IsSchedulable() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -446,6 +452,10 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
||||||
|
if s.schedulerSnapshot != nil {
|
||||||
|
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false)
|
||||||
|
return accounts, err
|
||||||
|
}
|
||||||
var accounts []Account
|
var accounts []Account
|
||||||
var err error
|
var err error
|
||||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||||
@@ -468,6 +478,13 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
|
|||||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||||
|
if s.schedulerSnapshot != nil {
|
||||||
|
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
return s.accountRepo.GetByID(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||||
if s.cfg != nil {
|
if s.cfg != nil {
|
||||||
return s.cfg.Gateway.Scheduling
|
return s.cfg.Gateway.Scheduling
|
||||||
@@ -540,16 +557,35 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
|
|
||||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||||
|
|
||||||
// Apply model mapping (skip for Codex CLI for transparent forwarding)
|
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||||
mappedModel := reqModel
|
mappedModel := account.GetMappedModel(reqModel)
|
||||||
if !isCodexCLI {
|
if mappedModel != reqModel {
|
||||||
mappedModel = account.GetMappedModel(reqModel)
|
log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
|
||||||
if mappedModel != reqModel {
|
reqBody["model"] = mappedModel
|
||||||
reqBody["model"] = mappedModel
|
bodyModified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
|
||||||
|
if model, ok := reqBody["model"].(string); ok {
|
||||||
|
normalizedModel := normalizeCodexModel(model)
|
||||||
|
if normalizedModel != "" && normalizedModel != model {
|
||||||
|
log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
|
||||||
|
model, normalizedModel, account.Name, account.Type, isCodexCLI)
|
||||||
|
reqBody["model"] = normalizedModel
|
||||||
|
mappedModel = normalizedModel
|
||||||
bodyModified = true
|
bodyModified = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
|
||||||
|
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
|
||||||
|
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
|
||||||
|
reasoning["effort"] = "none"
|
||||||
|
bodyModified = true
|
||||||
|
log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if account.Type == AccountTypeOAuth && !isCodexCLI {
|
if account.Type == AccountTypeOAuth && !isCodexCLI {
|
||||||
codexResult := applyCodexOAuthTransform(reqBody)
|
codexResult := applyCodexOAuthTransform(reqBody)
|
||||||
if codexResult.Modified {
|
if codexResult.Modified {
|
||||||
@@ -563,6 +599,44 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle max_output_tokens based on platform and account type
|
||||||
|
if !isCodexCLI {
|
||||||
|
if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens {
|
||||||
|
switch account.Platform {
|
||||||
|
case PlatformOpenAI:
|
||||||
|
// For OpenAI API Key, remove max_output_tokens (not supported)
|
||||||
|
// For OpenAI OAuth (Responses API), keep it (supported)
|
||||||
|
if account.Type == AccountTypeAPIKey {
|
||||||
|
delete(reqBody, "max_output_tokens")
|
||||||
|
bodyModified = true
|
||||||
|
}
|
||||||
|
case PlatformAnthropic:
|
||||||
|
// For Anthropic (Claude), convert to max_tokens
|
||||||
|
delete(reqBody, "max_output_tokens")
|
||||||
|
if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens {
|
||||||
|
reqBody["max_tokens"] = maxOutputTokens
|
||||||
|
}
|
||||||
|
bodyModified = true
|
||||||
|
case PlatformGemini:
|
||||||
|
// For Gemini, remove (will be handled by Gemini-specific transform)
|
||||||
|
delete(reqBody, "max_output_tokens")
|
||||||
|
bodyModified = true
|
||||||
|
default:
|
||||||
|
// For unknown platforms, remove to be safe
|
||||||
|
delete(reqBody, "max_output_tokens")
|
||||||
|
bodyModified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also handle max_completion_tokens (similar logic)
|
||||||
|
if _, hasMaxCompletionTokens := reqBody["max_completion_tokens"]; hasMaxCompletionTokens {
|
||||||
|
if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI {
|
||||||
|
delete(reqBody, "max_completion_tokens")
|
||||||
|
bodyModified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Re-serialize body only if modified
|
// Re-serialize body only if modified
|
||||||
if bodyModified {
|
if bodyModified {
|
||||||
var err error
|
var err error
|
||||||
@@ -590,6 +664,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture upstream request body for ops retry of this attempt.
|
||||||
|
if c != nil {
|
||||||
|
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
// Send request
|
// Send request
|
||||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -599,6 +678,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: 0,
|
UpstreamStatusCode: 0,
|
||||||
Kind: "request_error",
|
Kind: "request_error",
|
||||||
Message: safeErr,
|
Message: safeErr,
|
||||||
@@ -633,6 +713,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: "failover",
|
Kind: "failover",
|
||||||
@@ -742,9 +823,6 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
|||||||
if promptCacheKey != "" {
|
if promptCacheKey != "" {
|
||||||
req.Header.Set("conversation_id", promptCacheKey)
|
req.Header.Set("conversation_id", promptCacheKey)
|
||||||
req.Header.Set("session_id", promptCacheKey)
|
req.Header.Set("session_id", promptCacheKey)
|
||||||
} else {
|
|
||||||
req.Header.Del("conversation_id")
|
|
||||||
req.Header.Del("session_id")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -793,6 +871,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: "http_error",
|
Kind: "http_error",
|
||||||
@@ -823,6 +902,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
|||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
Kind: kind,
|
Kind: kind,
|
||||||
@@ -1042,6 +1122,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||||
|
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||||
|
}
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
@@ -1368,28 +1452,30 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
|
|
||||||
// Create usage log
|
// Create usage log
|
||||||
durationMs := int(result.Duration.Milliseconds())
|
durationMs := int(result.Duration.Milliseconds())
|
||||||
|
accountRateMultiplier := account.BillingRateMultiplier()
|
||||||
usageLog := &UsageLog{
|
usageLog := &UsageLog{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
APIKeyID: apiKey.ID,
|
APIKeyID: apiKey.ID,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
RequestID: result.RequestID,
|
RequestID: result.RequestID,
|
||||||
Model: result.Model,
|
Model: result.Model,
|
||||||
InputTokens: actualInputTokens,
|
InputTokens: actualInputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
InputCost: cost.InputCost,
|
InputCost: cost.InputCost,
|
||||||
OutputCost: cost.OutputCost,
|
OutputCost: cost.OutputCost,
|
||||||
CacheCreationCost: cost.CacheCreationCost,
|
CacheCreationCost: cost.CacheCreationCost,
|
||||||
CacheReadCost: cost.CacheReadCost,
|
CacheReadCost: cost.CacheReadCost,
|
||||||
TotalCost: cost.TotalCost,
|
TotalCost: cost.TotalCost,
|
||||||
ActualCost: cost.ActualCost,
|
ActualCost: cost.ActualCost,
|
||||||
RateMultiplier: multiplier,
|
RateMultiplier: multiplier,
|
||||||
BillingType: billingType,
|
AccountRateMultiplier: &accountRateMultiplier,
|
||||||
Stream: result.Stream,
|
BillingType: billingType,
|
||||||
DurationMs: &durationMs,
|
Stream: result.Stream,
|
||||||
FirstTokenMs: result.FirstTokenMs,
|
DurationMs: &durationMs,
|
||||||
CreatedAt: time.Now(),
|
FirstTokenMs: result.FirstTokenMs,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加 UserAgent
|
// 添加 UserAgent
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -15,6 +16,129 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type stubOpenAIAccountRepo struct {
|
||||||
|
AccountRepository
|
||||||
|
accounts []Account
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||||
|
return append([]Account(nil), r.accounts...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||||
|
return append([]Account(nil), r.accounts...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubConcurrencyCache struct {
|
||||||
|
ConcurrencyCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||||
|
out := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||||
|
for _, acc := range accounts {
|
||||||
|
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
resetAt := now.Add(10 * time.Minute)
|
||||||
|
groupID := int64(1)
|
||||||
|
|
||||||
|
rateLimited := Account{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
RateLimitResetAt: &resetAt,
|
||||||
|
}
|
||||||
|
available := Account{
|
||||||
|
ID: 2,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
|
||||||
|
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil {
|
||||||
|
t.Fatalf("expected selection with account")
|
||||||
|
}
|
||||||
|
if selection.Account.ID != available.ID {
|
||||||
|
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
|
||||||
|
}
|
||||||
|
if selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
resetAt := now.Add(10 * time.Minute)
|
||||||
|
groupID := int64(1)
|
||||||
|
|
||||||
|
rateLimited := Account{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
RateLimitResetAt: &resetAt,
|
||||||
|
}
|
||||||
|
available := Account{
|
||||||
|
ID: 2,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
|
||||||
|
// concurrencyService is nil, forcing the non-load-batch selection path.
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil {
|
||||||
|
t.Fatalf("expected selection with account")
|
||||||
|
}
|
||||||
|
if selection.Account.ID != available.ID {
|
||||||
|
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
|
||||||
|
}
|
||||||
|
if selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIStreamingTimeout(t *testing.T) {
|
func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
|
|||||||
213
backend/internal/service/openai_tool_continuation.go
Normal file
213
backend/internal/service/openai_tool_continuation.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
|
||||||
|
// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
|
||||||
|
// 或显式声明 tools/tool_choice。
|
||||||
|
func NeedsToolContinuation(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if hasNonEmptyString(reqBody["previous_response_id"]) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if hasToolsSignal(reqBody) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if hasToolChoiceSignal(reqBody) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if inputHasType(reqBody, "function_call_output") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if inputHasType(reqBody, "item_reference") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。
|
||||||
|
func HasFunctionCallOutput(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return inputHasType(reqBody, "function_call_output")
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call,
|
||||||
|
// 用于判断 function_call_output 是否具备可关联的上下文。
|
||||||
|
func HasToolCallContext(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "tool_call" && itemType != "function_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
|
||||||
|
// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。
|
||||||
|
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
|
||||||
|
if reqBody == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ids := make(map[string]struct{})
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "function_call_output" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||||
|
ids[callID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]string, 0, len(ids))
|
||||||
|
for id := range ids {
|
||||||
|
result = append(result, id)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
|
||||||
|
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "function_call_output" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callID, _ := itemMap["call_id"].(string)
|
||||||
|
if strings.TrimSpace(callID) == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
|
||||||
|
// 用于仅依赖引用项完成续链场景的校验。
|
||||||
|
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
|
||||||
|
if reqBody == nil || len(callIDs) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
referenceIDs := make(map[string]struct{})
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "item_reference" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idValue, _ := itemMap["id"].(string)
|
||||||
|
idValue = strings.TrimSpace(idValue)
|
||||||
|
if idValue == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
referenceIDs[idValue] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(referenceIDs) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, callID := range callIDs {
|
||||||
|
if _, ok := referenceIDs[callID]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// inputHasType 判断 input 中是否存在指定类型的 item。
|
||||||
|
func inputHasType(reqBody map[string]any, want string) bool {
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType == want {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasNonEmptyString 判断字段是否为非空字符串。
|
||||||
|
func hasNonEmptyString(value any) bool {
|
||||||
|
stringValue, ok := value.(string)
|
||||||
|
return ok && strings.TrimSpace(stringValue) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasToolsSignal 判断 tools 字段是否显式声明(存在且不为空)。
|
||||||
|
func hasToolsSignal(reqBody map[string]any) bool {
|
||||||
|
raw, exists := reqBody["tools"]
|
||||||
|
if !exists || raw == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if tools, ok := raw.([]any); ok {
|
||||||
|
return len(tools) > 0
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasToolChoiceSignal 判断 tool_choice 是否显式声明(非空或非 nil)。
|
||||||
|
func hasToolChoiceSignal(reqBody map[string]any) bool {
|
||||||
|
raw, exists := reqBody["tool_choice"]
|
||||||
|
if !exists || raw == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch value := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(value) != ""
|
||||||
|
case map[string]any:
|
||||||
|
return len(value) > 0
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
98
backend/internal/service/openai_tool_continuation_test.go
Normal file
98
backend/internal/service/openai_tool_continuation_test.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNeedsToolContinuationSignals(t *testing.T) {
|
||||||
|
// 覆盖所有触发续链的信号来源,确保判定逻辑完整。
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
body map[string]any
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "nil", body: nil, want: false},
|
||||||
|
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
|
||||||
|
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
|
||||||
|
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
|
||||||
|
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
|
||||||
|
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
|
||||||
|
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
|
||||||
|
{name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false},
|
||||||
|
{name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true},
|
||||||
|
{name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true},
|
||||||
|
{name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false},
|
||||||
|
{name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, NeedsToolContinuation(tt.body))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasFunctionCallOutput(t *testing.T) {
|
||||||
|
// 仅当 input 中存在 function_call_output 才视为续链输出。
|
||||||
|
require.False(t, HasFunctionCallOutput(nil))
|
||||||
|
require.True(t, HasFunctionCallOutput(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||||
|
}))
|
||||||
|
require.False(t, HasFunctionCallOutput(map[string]any{
|
||||||
|
"input": "text",
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasToolCallContext(t *testing.T) {
|
||||||
|
// tool_call/function_call 必须包含 call_id,才能作为可关联上下文。
|
||||||
|
require.False(t, HasToolCallContext(nil))
|
||||||
|
require.True(t, HasToolCallContext(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}},
|
||||||
|
}))
|
||||||
|
require.True(t, HasToolCallContext(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}},
|
||||||
|
}))
|
||||||
|
require.False(t, HasToolCallContext(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "tool_call"}},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFunctionCallOutputCallIDs(t *testing.T) {
|
||||||
|
// 仅提取非空 call_id,去重后返回。
|
||||||
|
require.Empty(t, FunctionCallOutputCallIDs(nil))
|
||||||
|
callIDs := FunctionCallOutputCallIDs(map[string]any{
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": ""},
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.ElementsMatch(t, []string{"call_1"}, callIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
|
||||||
|
require.False(t, HasFunctionCallOutputMissingCallID(nil))
|
||||||
|
require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||||
|
}))
|
||||||
|
require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasItemReferenceForCallIDs(t *testing.T) {
|
||||||
|
// item_reference 需要覆盖所有 call_id 才视为可关联上下文。
|
||||||
|
require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"}))
|
||||||
|
require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"}))
|
||||||
|
req := map[string]any{
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "item_reference", "id": "call_1"},
|
||||||
|
map[string]any{"type": "item_reference", "id": "call_2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1"}))
|
||||||
|
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_2"}))
|
||||||
|
require.False(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_3"}))
|
||||||
|
}
|
||||||
@@ -206,7 +206,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
scopePlatform, scopeGroupID := parseOpsAlertRuleScope(rule.Filters)
|
scopePlatform, scopeGroupID, scopeRegion := parseOpsAlertRuleScope(rule.Filters)
|
||||||
|
|
||||||
windowMinutes := rule.WindowMinutes
|
windowMinutes := rule.WindowMinutes
|
||||||
if windowMinutes <= 0 {
|
if windowMinutes <= 0 {
|
||||||
@@ -236,6 +236,17 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scoped silencing: if a matching silence exists, skip creating a firing event.
|
||||||
|
if s.opsService != nil {
|
||||||
|
platform := strings.TrimSpace(scopePlatform)
|
||||||
|
region := scopeRegion
|
||||||
|
if platform != "" {
|
||||||
|
if ok, err := s.opsService.IsAlertSilenced(ctx, rule.ID, platform, scopeGroupID, region, now); err == nil && ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
|
latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
|
log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
|
||||||
@@ -359,9 +370,9 @@ func requiredSustainedBreaches(sustainedMinutes int, interval time.Duration) int
|
|||||||
return required
|
return required
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64) {
|
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64, region *string) {
|
||||||
if filters == nil {
|
if filters == nil {
|
||||||
return "", nil
|
return "", nil, nil
|
||||||
}
|
}
|
||||||
if v, ok := filters["platform"]; ok {
|
if v, ok := filters["platform"]; ok {
|
||||||
if s, ok := v.(string); ok {
|
if s, ok := v.(string); ok {
|
||||||
@@ -392,7 +403,15 @@ func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *i
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return platform, groupID
|
if v, ok := filters["region"]; ok {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
vv := strings.TrimSpace(s)
|
||||||
|
if vv != "" {
|
||||||
|
region = &vv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return platform, groupID, region
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpsAlertEvaluatorService) computeRuleMetric(
|
func (s *OpsAlertEvaluatorService) computeRuleMetric(
|
||||||
@@ -504,16 +523,6 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric(
|
|||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
return overview.UpstreamErrorRate * 100, true
|
return overview.UpstreamErrorRate * 100, true
|
||||||
case "p95_latency_ms":
|
|
||||||
if overview.Duration.P95 == nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
return float64(*overview.Duration.P95), true
|
|
||||||
case "p99_latency_ms":
|
|
||||||
if overview.Duration.P99 == nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
return float64(*overview.Duration.P99), true
|
|
||||||
default:
|
default:
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,8 +8,9 @@ import "time"
|
|||||||
// with the existing ops dashboard frontend (backup style).
|
// with the existing ops dashboard frontend (backup style).
|
||||||
|
|
||||||
const (
|
const (
|
||||||
OpsAlertStatusFiring = "firing"
|
OpsAlertStatusFiring = "firing"
|
||||||
OpsAlertStatusResolved = "resolved"
|
OpsAlertStatusResolved = "resolved"
|
||||||
|
OpsAlertStatusManualResolved = "manual_resolved"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OpsAlertRule struct {
|
type OpsAlertRule struct {
|
||||||
@@ -58,12 +59,32 @@ type OpsAlertEvent struct {
|
|||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OpsAlertSilence struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
|
||||||
|
RuleID int64 `json:"rule_id"`
|
||||||
|
Platform string `json:"platform"`
|
||||||
|
GroupID *int64 `json:"group_id,omitempty"`
|
||||||
|
Region *string `json:"region,omitempty"`
|
||||||
|
|
||||||
|
Until time.Time `json:"until"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
|
||||||
|
CreatedBy *int64 `json:"created_by,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
type OpsAlertEventFilter struct {
|
type OpsAlertEventFilter struct {
|
||||||
Limit int
|
Limit int
|
||||||
|
|
||||||
|
// Cursor pagination (descending by fired_at, then id).
|
||||||
|
BeforeFiredAt *time.Time
|
||||||
|
BeforeID *int64
|
||||||
|
|
||||||
// Optional filters.
|
// Optional filters.
|
||||||
Status string
|
Status string
|
||||||
Severity string
|
Severity string
|
||||||
|
EmailSent *bool
|
||||||
|
|
||||||
StartTime *time.Time
|
StartTime *time.Time
|
||||||
EndTime *time.Time
|
EndTime *time.Time
|
||||||
|
|||||||
@@ -88,6 +88,29 @@ func (s *OpsService) ListAlertEvents(ctx context.Context, filter *OpsAlertEventF
|
|||||||
return s.opsRepo.ListAlertEvents(ctx, filter)
|
return s.opsRepo.ListAlertEvents(ctx, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpsService) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) {
|
||||||
|
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.opsRepo == nil {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||||
|
}
|
||||||
|
if eventID <= 0 {
|
||||||
|
return nil, infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
|
||||||
|
}
|
||||||
|
ev, err := s.opsRepo.GetAlertEventByID(ctx, eventID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if ev == nil {
|
||||||
|
return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
|
||||||
|
}
|
||||||
|
return ev, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -101,6 +124,49 @@ func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*Op
|
|||||||
return s.opsRepo.GetActiveAlertEvent(ctx, ruleID)
|
return s.opsRepo.GetActiveAlertEvent(ctx, ruleID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpsService) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) {
|
||||||
|
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.opsRepo == nil {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||||
|
}
|
||||||
|
if input == nil {
|
||||||
|
return nil, infraerrors.BadRequest("INVALID_SILENCE", "invalid silence")
|
||||||
|
}
|
||||||
|
if input.RuleID <= 0 {
|
||||||
|
return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(input.Platform) == "" {
|
||||||
|
return nil, infraerrors.BadRequest("INVALID_PLATFORM", "invalid platform")
|
||||||
|
}
|
||||||
|
if input.Until.IsZero() {
|
||||||
|
return nil, infraerrors.BadRequest("INVALID_UNTIL", "invalid until")
|
||||||
|
}
|
||||||
|
|
||||||
|
created, err := s.opsRepo.CreateAlertSilence(ctx, input)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return created, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsService) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
|
||||||
|
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if s.opsRepo == nil {
|
||||||
|
return false, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||||
|
}
|
||||||
|
if ruleID <= 0 {
|
||||||
|
return false, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(platform) == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return s.opsRepo.IsAlertSilenced(ctx, ruleID, platform, groupID, region, now)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -142,7 +208,11 @@ func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64,
|
|||||||
if eventID <= 0 {
|
if eventID <= 0 {
|
||||||
return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
|
return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(status) == "" {
|
status = strings.TrimSpace(status)
|
||||||
|
if status == "" {
|
||||||
|
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
|
||||||
|
}
|
||||||
|
if status != OpsAlertStatusResolved && status != OpsAlertStatusManualResolved {
|
||||||
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
|
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
|
||||||
}
|
}
|
||||||
return s.opsRepo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
|
return s.opsRepo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
|
||||||
|
|||||||
@@ -32,49 +32,38 @@ func computeDashboardHealthScore(now time.Time, overview *OpsDashboardOverview)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// computeBusinessHealth calculates business health score (0-100)
|
// computeBusinessHealth calculates business health score (0-100)
|
||||||
// Components: SLA (50%) + Error Rate (30%) + Latency (20%)
|
// Components: Error Rate (50%) + TTFT (50%)
|
||||||
func computeBusinessHealth(overview *OpsDashboardOverview) float64 {
|
func computeBusinessHealth(overview *OpsDashboardOverview) float64 {
|
||||||
// SLA score: 99.5% → 100, 95% → 0 (linear)
|
// Error rate score: 1% → 100, 10% → 0 (linear)
|
||||||
slaScore := 100.0
|
|
||||||
slaPct := clampFloat64(overview.SLA*100, 0, 100)
|
|
||||||
if slaPct < 99.5 {
|
|
||||||
if slaPct >= 95 {
|
|
||||||
slaScore = (slaPct - 95) / 4.5 * 100
|
|
||||||
} else {
|
|
||||||
slaScore = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error rate score: 0.5% → 100, 5% → 0 (linear)
|
|
||||||
// Combines request errors and upstream errors
|
// Combines request errors and upstream errors
|
||||||
errorScore := 100.0
|
errorScore := 100.0
|
||||||
errorPct := clampFloat64(overview.ErrorRate*100, 0, 100)
|
errorPct := clampFloat64(overview.ErrorRate*100, 0, 100)
|
||||||
upstreamPct := clampFloat64(overview.UpstreamErrorRate*100, 0, 100)
|
upstreamPct := clampFloat64(overview.UpstreamErrorRate*100, 0, 100)
|
||||||
combinedErrorPct := math.Max(errorPct, upstreamPct) // Use worst case
|
combinedErrorPct := math.Max(errorPct, upstreamPct) // Use worst case
|
||||||
if combinedErrorPct > 0.5 {
|
if combinedErrorPct > 1.0 {
|
||||||
if combinedErrorPct <= 5 {
|
if combinedErrorPct <= 10.0 {
|
||||||
errorScore = (5 - combinedErrorPct) / 4.5 * 100
|
errorScore = (10.0 - combinedErrorPct) / 9.0 * 100
|
||||||
} else {
|
} else {
|
||||||
errorScore = 0
|
errorScore = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Latency score: 1s → 100, 10s → 0 (linear)
|
// TTFT score: 1s → 100, 3s → 0 (linear)
|
||||||
// Uses P99 of duration (TTFT is less critical for overall health)
|
// Time to first token is critical for user experience
|
||||||
latencyScore := 100.0
|
ttftScore := 100.0
|
||||||
if overview.Duration.P99 != nil {
|
if overview.TTFT.P99 != nil {
|
||||||
p99 := float64(*overview.Duration.P99)
|
p99 := float64(*overview.TTFT.P99)
|
||||||
if p99 > 1000 {
|
if p99 > 1000 {
|
||||||
if p99 <= 10000 {
|
if p99 <= 3000 {
|
||||||
latencyScore = (10000 - p99) / 9000 * 100
|
ttftScore = (3000 - p99) / 2000 * 100
|
||||||
} else {
|
} else {
|
||||||
latencyScore = 0
|
ttftScore = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Weighted combination
|
// Weighted combination: 50% error rate + 50% TTFT
|
||||||
return slaScore*0.5 + errorScore*0.3 + latencyScore*0.2
|
return errorScore*0.5 + ttftScore*0.5
|
||||||
}
|
}
|
||||||
|
|
||||||
// computeInfraHealth calculates infrastructure health score (0-100)
|
// computeInfraHealth calculates infrastructure health score (0-100)
|
||||||
|
|||||||
@@ -127,8 +127,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
|
|||||||
MemoryUsagePercent: float64Ptr(75),
|
MemoryUsagePercent: float64Ptr(75),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
wantMin: 60,
|
wantMin: 96,
|
||||||
wantMax: 85,
|
wantMax: 97,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "DB failure",
|
name: "DB failure",
|
||||||
@@ -203,8 +203,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
|
|||||||
MemoryUsagePercent: float64Ptr(30),
|
MemoryUsagePercent: float64Ptr(30),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
wantMin: 25,
|
wantMin: 84,
|
||||||
wantMax: 50,
|
wantMax: 85,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "combined failures - business healthy + infra degraded",
|
name: "combined failures - business healthy + infra degraded",
|
||||||
@@ -277,30 +277,41 @@ func TestComputeBusinessHealth(t *testing.T) {
|
|||||||
UpstreamErrorRate: 0,
|
UpstreamErrorRate: 0,
|
||||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||||
},
|
},
|
||||||
wantMin: 50,
|
wantMin: 100,
|
||||||
wantMax: 60,
|
wantMax: 100,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error rate boundary 0.5%",
|
name: "error rate boundary 1%",
|
||||||
overview: &OpsDashboardOverview{
|
overview: &OpsDashboardOverview{
|
||||||
SLA: 0.995,
|
SLA: 0.99,
|
||||||
ErrorRate: 0.005,
|
ErrorRate: 0.01,
|
||||||
UpstreamErrorRate: 0,
|
UpstreamErrorRate: 0,
|
||||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||||
},
|
},
|
||||||
wantMin: 95,
|
wantMin: 100,
|
||||||
wantMax: 100,
|
wantMax: 100,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "latency boundary 1000ms",
|
name: "error rate 5%",
|
||||||
overview: &OpsDashboardOverview{
|
overview: &OpsDashboardOverview{
|
||||||
SLA: 0.995,
|
SLA: 0.95,
|
||||||
|
ErrorRate: 0.05,
|
||||||
|
UpstreamErrorRate: 0,
|
||||||
|
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||||
|
},
|
||||||
|
wantMin: 77,
|
||||||
|
wantMax: 78,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TTFT boundary 2s",
|
||||||
|
overview: &OpsDashboardOverview{
|
||||||
|
SLA: 0.99,
|
||||||
ErrorRate: 0,
|
ErrorRate: 0,
|
||||||
UpstreamErrorRate: 0,
|
UpstreamErrorRate: 0,
|
||||||
Duration: OpsPercentiles{P99: intPtr(1000)},
|
TTFT: OpsPercentiles{P99: intPtr(2000)},
|
||||||
},
|
},
|
||||||
wantMin: 95,
|
wantMin: 75,
|
||||||
wantMax: 100,
|
wantMax: 75,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "upstream error dominates",
|
name: "upstream error dominates",
|
||||||
@@ -310,7 +321,7 @@ func TestComputeBusinessHealth(t *testing.T) {
|
|||||||
UpstreamErrorRate: 0.03,
|
UpstreamErrorRate: 0.03,
|
||||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||||
},
|
},
|
||||||
wantMin: 75,
|
wantMin: 88,
|
||||||
wantMax: 90,
|
wantMax: 90,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,24 +6,43 @@ type OpsErrorLog struct {
|
|||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
|
||||||
Phase string `json:"phase"`
|
// Standardized classification
|
||||||
Type string `json:"type"`
|
// - phase: request|auth|routing|upstream|network|internal
|
||||||
|
// - owner: client|provider|platform
|
||||||
|
// - source: client_request|upstream_http|gateway
|
||||||
|
Phase string `json:"phase"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
Owner string `json:"error_owner"`
|
||||||
|
Source string `json:"error_source"`
|
||||||
|
|
||||||
Severity string `json:"severity"`
|
Severity string `json:"severity"`
|
||||||
|
|
||||||
StatusCode int `json:"status_code"`
|
StatusCode int `json:"status_code"`
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
|
||||||
LatencyMs *int `json:"latency_ms"`
|
IsRetryable bool `json:"is_retryable"`
|
||||||
|
RetryCount int `json:"retry_count"`
|
||||||
|
|
||||||
|
Resolved bool `json:"resolved"`
|
||||||
|
ResolvedAt *time.Time `json:"resolved_at"`
|
||||||
|
ResolvedByUserID *int64 `json:"resolved_by_user_id"`
|
||||||
|
ResolvedByUserName string `json:"resolved_by_user_name"`
|
||||||
|
ResolvedRetryID *int64 `json:"resolved_retry_id"`
|
||||||
|
ResolvedStatusRaw string `json:"-"`
|
||||||
|
|
||||||
ClientRequestID string `json:"client_request_id"`
|
ClientRequestID string `json:"client_request_id"`
|
||||||
RequestID string `json:"request_id"`
|
RequestID string `json:"request_id"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
|
|
||||||
UserID *int64 `json:"user_id"`
|
UserID *int64 `json:"user_id"`
|
||||||
APIKeyID *int64 `json:"api_key_id"`
|
UserEmail string `json:"user_email"`
|
||||||
AccountID *int64 `json:"account_id"`
|
APIKeyID *int64 `json:"api_key_id"`
|
||||||
GroupID *int64 `json:"group_id"`
|
AccountID *int64 `json:"account_id"`
|
||||||
|
AccountName string `json:"account_name"`
|
||||||
|
GroupID *int64 `json:"group_id"`
|
||||||
|
GroupName string `json:"group_name"`
|
||||||
|
|
||||||
ClientIP *string `json:"client_ip"`
|
ClientIP *string `json:"client_ip"`
|
||||||
RequestPath string `json:"request_path"`
|
RequestPath string `json:"request_path"`
|
||||||
@@ -67,9 +86,24 @@ type OpsErrorLogFilter struct {
|
|||||||
GroupID *int64
|
GroupID *int64
|
||||||
AccountID *int64
|
AccountID *int64
|
||||||
|
|
||||||
StatusCodes []int
|
StatusCodes []int
|
||||||
Phase string
|
StatusCodesOther bool
|
||||||
Query string
|
Phase string
|
||||||
|
Owner string
|
||||||
|
Source string
|
||||||
|
Resolved *bool
|
||||||
|
Query string
|
||||||
|
UserQuery string // Search by user email
|
||||||
|
|
||||||
|
// Optional correlation keys for exact matching.
|
||||||
|
RequestID string
|
||||||
|
ClientRequestID string
|
||||||
|
|
||||||
|
// View controls error categorization for list endpoints.
|
||||||
|
// - errors: show actionable errors (exclude business-limited / 429 / 529)
|
||||||
|
// - excluded: only show excluded errors
|
||||||
|
// - all: show everything
|
||||||
|
View string
|
||||||
|
|
||||||
Page int
|
Page int
|
||||||
PageSize int
|
PageSize int
|
||||||
@@ -90,12 +124,23 @@ type OpsRetryAttempt struct {
|
|||||||
SourceErrorID int64 `json:"source_error_id"`
|
SourceErrorID int64 `json:"source_error_id"`
|
||||||
Mode string `json:"mode"`
|
Mode string `json:"mode"`
|
||||||
PinnedAccountID *int64 `json:"pinned_account_id"`
|
PinnedAccountID *int64 `json:"pinned_account_id"`
|
||||||
|
PinnedAccountName string `json:"pinned_account_name"`
|
||||||
|
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
StartedAt *time.Time `json:"started_at"`
|
StartedAt *time.Time `json:"started_at"`
|
||||||
FinishedAt *time.Time `json:"finished_at"`
|
FinishedAt *time.Time `json:"finished_at"`
|
||||||
DurationMs *int64 `json:"duration_ms"`
|
DurationMs *int64 `json:"duration_ms"`
|
||||||
|
|
||||||
|
// Persisted execution results (best-effort)
|
||||||
|
Success *bool `json:"success"`
|
||||||
|
HTTPStatusCode *int `json:"http_status_code"`
|
||||||
|
UpstreamRequestID *string `json:"upstream_request_id"`
|
||||||
|
UsedAccountID *int64 `json:"used_account_id"`
|
||||||
|
UsedAccountName string `json:"used_account_name"`
|
||||||
|
ResponsePreview *string `json:"response_preview"`
|
||||||
|
ResponseTruncated *bool `json:"response_truncated"`
|
||||||
|
|
||||||
|
// Optional correlation
|
||||||
ResultRequestID *string `json:"result_request_id"`
|
ResultRequestID *string `json:"result_request_id"`
|
||||||
ResultErrorID *int64 `json:"result_error_id"`
|
ResultErrorID *int64 `json:"result_error_id"`
|
||||||
|
|
||||||
|
|||||||
@@ -14,9 +14,13 @@ type OpsRepository interface {
|
|||||||
InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error)
|
InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error)
|
||||||
UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error
|
UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error
|
||||||
GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error)
|
GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error)
|
||||||
|
ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error)
|
||||||
|
UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error
|
||||||
|
|
||||||
// Lightweight window stats (for realtime WS / quick sampling).
|
// Lightweight window stats (for realtime WS / quick sampling).
|
||||||
GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error)
|
GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error)
|
||||||
|
// Lightweight realtime traffic summary (for the Ops dashboard header card).
|
||||||
|
GetRealtimeTrafficSummary(ctx context.Context, filter *OpsDashboardFilter) (*OpsRealtimeTrafficSummary, error)
|
||||||
|
|
||||||
GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error)
|
GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error)
|
||||||
GetThroughputTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsThroughputTrendResponse, error)
|
GetThroughputTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsThroughputTrendResponse, error)
|
||||||
@@ -37,12 +41,17 @@ type OpsRepository interface {
|
|||||||
DeleteAlertRule(ctx context.Context, id int64) error
|
DeleteAlertRule(ctx context.Context, id int64) error
|
||||||
|
|
||||||
ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error)
|
ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error)
|
||||||
|
GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error)
|
||||||
GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
|
GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
|
||||||
GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
|
GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
|
||||||
CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error)
|
CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error)
|
||||||
UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error
|
UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error
|
||||||
UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error
|
UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error
|
||||||
|
|
||||||
|
// Alert silences
|
||||||
|
CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error)
|
||||||
|
IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error)
|
||||||
|
|
||||||
// Pre-aggregation (hourly/daily) used for long-window dashboard performance.
|
// Pre-aggregation (hourly/daily) used for long-window dashboard performance.
|
||||||
UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error
|
UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error
|
||||||
UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error
|
UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error
|
||||||
@@ -71,6 +80,7 @@ type OpsInsertErrorLogInput struct {
|
|||||||
Severity string
|
Severity string
|
||||||
StatusCode int
|
StatusCode int
|
||||||
IsBusinessLimited bool
|
IsBusinessLimited bool
|
||||||
|
IsCountTokens bool // 是否为 count_tokens 请求
|
||||||
|
|
||||||
ErrorMessage string
|
ErrorMessage string
|
||||||
ErrorBody string
|
ErrorBody string
|
||||||
@@ -88,7 +98,6 @@ type OpsInsertErrorLogInput struct {
|
|||||||
// It is set by OpsService.RecordError before persisting.
|
// It is set by OpsService.RecordError before persisting.
|
||||||
UpstreamErrorsJSON *string
|
UpstreamErrorsJSON *string
|
||||||
|
|
||||||
DurationMs *int
|
|
||||||
TimeToFirstTokenMs *int64
|
TimeToFirstTokenMs *int64
|
||||||
|
|
||||||
RequestBodyJSON *string // sanitized json string (not raw bytes)
|
RequestBodyJSON *string // sanitized json string (not raw bytes)
|
||||||
@@ -121,7 +130,15 @@ type OpsUpdateRetryAttemptInput struct {
|
|||||||
FinishedAt time.Time
|
FinishedAt time.Time
|
||||||
DurationMs int64
|
DurationMs int64
|
||||||
|
|
||||||
// Optional correlation
|
// Persisted execution results (best-effort)
|
||||||
|
Success *bool
|
||||||
|
HTTPStatusCode *int
|
||||||
|
UpstreamRequestID *string
|
||||||
|
UsedAccountID *int64
|
||||||
|
ResponsePreview *string
|
||||||
|
ResponseTruncated *bool
|
||||||
|
|
||||||
|
// Optional correlation (legacy fields kept)
|
||||||
ResultRequestID *string
|
ResultRequestID *string
|
||||||
ResultErrorID *int64
|
ResultErrorID *int64
|
||||||
|
|
||||||
|
|||||||
36
backend/internal/service/ops_realtime_traffic.go
Normal file
36
backend/internal/service/ops_realtime_traffic.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetRealtimeTrafficSummary returns QPS/TPS current/peak/avg for the provided window.
|
||||||
|
// This is used by the Ops dashboard "Realtime Traffic" card and is intentionally lightweight.
|
||||||
|
func (s *OpsService) GetRealtimeTrafficSummary(ctx context.Context, filter *OpsDashboardFilter) (*OpsRealtimeTrafficSummary, error) {
|
||||||
|
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.opsRepo == nil {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||||
|
}
|
||||||
|
if filter == nil {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
|
||||||
|
}
|
||||||
|
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
|
||||||
|
}
|
||||||
|
if filter.StartTime.After(filter.EndTime) {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||||
|
}
|
||||||
|
if filter.EndTime.Sub(filter.StartTime) > time.Hour {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_TOO_LARGE", "invalid time range: max window is 1 hour")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Realtime traffic summary always uses raw logs (minute granularity peaks).
|
||||||
|
filter.QueryMode = OpsQueryModeRaw
|
||||||
|
|
||||||
|
return s.opsRepo.GetRealtimeTrafficSummary(ctx, filter)
|
||||||
|
}
|
||||||
19
backend/internal/service/ops_realtime_traffic_models.go
Normal file
19
backend/internal/service/ops_realtime_traffic_models.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// OpsRealtimeTrafficSummary is a lightweight summary used by the Ops dashboard "Realtime Traffic" card.
|
||||||
|
// It reports QPS/TPS current/peak/avg for the requested time window.
|
||||||
|
type OpsRealtimeTrafficSummary struct {
|
||||||
|
// Window is a normalized label (e.g. "1min", "5min", "30min", "1h").
|
||||||
|
Window string `json:"window"`
|
||||||
|
|
||||||
|
StartTime time.Time `json:"start_time"`
|
||||||
|
EndTime time.Time `json:"end_time"`
|
||||||
|
|
||||||
|
Platform string `json:"platform"`
|
||||||
|
GroupID *int64 `json:"group_id"`
|
||||||
|
|
||||||
|
QPS OpsRateSummary `json:"qps"`
|
||||||
|
TPS OpsRateSummary `json:"tps"`
|
||||||
|
}
|
||||||
@@ -108,6 +108,10 @@ func (w *limitedResponseWriter) truncated() bool {
|
|||||||
return w.totalWritten > int64(w.limit)
|
return w.totalWritten > int64(w.limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
OpsRetryModeUpstreamEvent = "upstream_event"
|
||||||
|
)
|
||||||
|
|
||||||
func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, errorID int64, mode string, pinnedAccountID *int64) (*OpsRetryResult, error) {
|
func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, errorID int64, mode string, pinnedAccountID *int64) (*OpsRetryResult, error) {
|
||||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -123,6 +127,81 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
|
|||||||
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_MODE", "mode must be client or upstream")
|
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_MODE", "mode must be client or upstream")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
errorLog, err := s.GetErrorLogByID(ctx, errorID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if errorLog == nil {
|
||||||
|
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(errorLog.RequestBody) == "" {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
|
||||||
|
}
|
||||||
|
|
||||||
|
var pinned *int64
|
||||||
|
if mode == OpsRetryModeUpstream {
|
||||||
|
if pinnedAccountID != nil && *pinnedAccountID > 0 {
|
||||||
|
pinned = pinnedAccountID
|
||||||
|
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
|
||||||
|
pinned = errorLog.AccountID
|
||||||
|
} else {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.retryWithErrorLog(ctx, requestedByUserID, errorID, mode, mode, pinned, errorLog)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryUpstreamEvent retries a specific upstream attempt captured inside ops_error_logs.upstream_errors.
|
||||||
|
// idx is 0-based. It always pins the original event account_id.
|
||||||
|
func (s *OpsService) RetryUpstreamEvent(ctx context.Context, requestedByUserID int64, errorID int64, idx int) (*OpsRetryResult, error) {
|
||||||
|
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.opsRepo == nil {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||||
|
}
|
||||||
|
if idx < 0 {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_UPSTREAM_IDX", "invalid upstream idx")
|
||||||
|
}
|
||||||
|
|
||||||
|
errorLog, err := s.GetErrorLogByID(ctx, errorID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if errorLog == nil {
|
||||||
|
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := ParseOpsUpstreamErrors(errorLog.UpstreamErrors)
|
||||||
|
if err != nil {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENTS_INVALID", "invalid upstream_errors")
|
||||||
|
}
|
||||||
|
if idx >= len(events) {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_IDX_OOB", "upstream idx out of range")
|
||||||
|
}
|
||||||
|
ev := events[idx]
|
||||||
|
if ev == nil {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENT_MISSING", "upstream event missing")
|
||||||
|
}
|
||||||
|
if ev.AccountID <= 0 {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamBody := strings.TrimSpace(ev.UpstreamRequestBody)
|
||||||
|
if upstreamBody == "" {
|
||||||
|
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_NO_REQUEST_BODY", "No upstream request body found to retry")
|
||||||
|
}
|
||||||
|
|
||||||
|
override := *errorLog
|
||||||
|
override.RequestBody = upstreamBody
|
||||||
|
pinned := ev.AccountID
|
||||||
|
|
||||||
|
// Persist as upstream_event, execute as upstream pinned retry.
|
||||||
|
return s.retryWithErrorLog(ctx, requestedByUserID, errorID, OpsRetryModeUpstreamEvent, OpsRetryModeUpstream, &pinned, &override)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsService) retryWithErrorLog(ctx context.Context, requestedByUserID int64, errorID int64, mode string, execMode string, pinnedAccountID *int64, errorLog *OpsErrorLogDetail) (*OpsRetryResult, error) {
|
||||||
latest, err := s.opsRepo.GetLatestRetryAttemptForError(ctx, errorID)
|
latest, err := s.opsRepo.GetLatestRetryAttemptForError(ctx, errorID)
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, infraerrors.InternalServer("OPS_RETRY_LOAD_LATEST_FAILED", "Failed to check retry status").WithCause(err)
|
return nil, infraerrors.InternalServer("OPS_RETRY_LOAD_LATEST_FAILED", "Failed to check retry status").WithCause(err)
|
||||||
@@ -144,22 +223,18 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
errorLog, err := s.GetErrorLogByID(ctx, errorID)
|
if errorLog == nil || strings.TrimSpace(errorLog.RequestBody) == "" {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(errorLog.RequestBody) == "" {
|
|
||||||
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
|
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
|
||||||
}
|
}
|
||||||
|
|
||||||
var pinned *int64
|
var pinned *int64
|
||||||
if mode == OpsRetryModeUpstream {
|
if execMode == OpsRetryModeUpstream {
|
||||||
if pinnedAccountID != nil && *pinnedAccountID > 0 {
|
if pinnedAccountID != nil && *pinnedAccountID > 0 {
|
||||||
pinned = pinnedAccountID
|
pinned = pinnedAccountID
|
||||||
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
|
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
|
||||||
pinned = errorLog.AccountID
|
pinned = errorLog.AccountID
|
||||||
} else {
|
} else {
|
||||||
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
|
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,7 +271,7 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
|
|||||||
execCtx, cancel := context.WithTimeout(ctx, opsRetryTimeout)
|
execCtx, cancel := context.WithTimeout(ctx, opsRetryTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
execRes := s.executeRetry(execCtx, errorLog, mode, pinned)
|
execRes := s.executeRetry(execCtx, errorLog, execMode, pinned)
|
||||||
|
|
||||||
finishedAt := time.Now()
|
finishedAt := time.Now()
|
||||||
result.FinishedAt = finishedAt
|
result.FinishedAt = finishedAt
|
||||||
@@ -220,27 +295,40 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
|
|||||||
msg := result.ErrorMessage
|
msg := result.ErrorMessage
|
||||||
updateErrMsg = &msg
|
updateErrMsg = &msg
|
||||||
}
|
}
|
||||||
|
// Keep legacy result_request_id empty; use upstream_request_id instead.
|
||||||
var resultRequestID *string
|
var resultRequestID *string
|
||||||
if strings.TrimSpace(result.UpstreamRequestID) != "" {
|
|
||||||
v := result.UpstreamRequestID
|
|
||||||
resultRequestID = &v
|
|
||||||
}
|
|
||||||
|
|
||||||
finalStatus := result.Status
|
finalStatus := result.Status
|
||||||
if strings.TrimSpace(finalStatus) == "" {
|
if strings.TrimSpace(finalStatus) == "" {
|
||||||
finalStatus = opsRetryStatusFailed
|
finalStatus = opsRetryStatusFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
success := strings.EqualFold(finalStatus, opsRetryStatusSucceeded)
|
||||||
|
httpStatus := result.HTTPStatusCode
|
||||||
|
upstreamReqID := result.UpstreamRequestID
|
||||||
|
usedAccountID := result.UsedAccountID
|
||||||
|
preview := result.ResponsePreview
|
||||||
|
truncated := result.ResponseTruncated
|
||||||
|
|
||||||
if err := s.opsRepo.UpdateRetryAttempt(updateCtx, &OpsUpdateRetryAttemptInput{
|
if err := s.opsRepo.UpdateRetryAttempt(updateCtx, &OpsUpdateRetryAttemptInput{
|
||||||
ID: attemptID,
|
ID: attemptID,
|
||||||
Status: finalStatus,
|
Status: finalStatus,
|
||||||
FinishedAt: finishedAt,
|
FinishedAt: finishedAt,
|
||||||
DurationMs: result.DurationMs,
|
DurationMs: result.DurationMs,
|
||||||
ResultRequestID: resultRequestID,
|
Success: &success,
|
||||||
ErrorMessage: updateErrMsg,
|
HTTPStatusCode: &httpStatus,
|
||||||
|
UpstreamRequestID: &upstreamReqID,
|
||||||
|
UsedAccountID: usedAccountID,
|
||||||
|
ResponsePreview: &preview,
|
||||||
|
ResponseTruncated: &truncated,
|
||||||
|
ResultRequestID: resultRequestID,
|
||||||
|
ErrorMessage: updateErrMsg,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
// Best-effort: retry itself already executed; do not fail the API response.
|
|
||||||
log.Printf("[Ops] UpdateRetryAttempt failed: %v", err)
|
log.Printf("[Ops] UpdateRetryAttempt failed: %v", err)
|
||||||
|
} else if success {
|
||||||
|
if err := s.opsRepo.UpdateErrorResolution(updateCtx, errorID, true, &requestedByUserID, &attemptID, &finishedAt); err != nil {
|
||||||
|
log.Printf("[Ops] UpdateErrorResolution failed: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user