mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 15:32:13 +08:00
Compare commits
280 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
426d691c95 | ||
|
|
e9a4c8ab97 | ||
|
|
34cc02f8c7 | ||
|
|
624d9fddb7 | ||
|
|
47fbe43324 | ||
|
|
1245f07a2d | ||
|
|
839975b0cf | ||
|
|
8c1233393f | ||
|
|
9cdb0568cc | ||
|
|
74e05b83ea | ||
|
|
4ded9e7d49 | ||
|
|
716272a1e2 | ||
|
|
9cc8352593 | ||
|
|
43a1031e38 | ||
|
|
a5547b2f30 | ||
|
|
b0aa23540b | ||
|
|
ffaa6c4a17 | ||
|
|
fbf72f0ec4 | ||
|
|
909b8a8f9c | ||
|
|
4a0fe3b143 | ||
|
|
a1292fac81 | ||
|
|
7f98be4f91 | ||
|
|
fd73b8875d | ||
|
|
f9ab1daa3c | ||
|
|
d27b847442 | ||
|
|
dac6bc2228 | ||
|
|
4bd3dbf2ce | ||
|
|
226df1c23a | ||
|
|
2665230a09 | ||
|
|
4f0c2b794c | ||
|
|
e756064c19 | ||
|
|
17dfb0af01 | ||
|
|
ff74f517df | ||
|
|
477a9a180f | ||
|
|
da48df06d2 | ||
|
|
39fad63ccf | ||
|
|
5602d02b1b | ||
|
|
81989eed1c | ||
|
|
192efb84a0 | ||
|
|
8672347f93 | ||
|
|
5e5d4a513b | ||
|
|
88b6358472 | ||
|
|
dd8d5e2c42 | ||
|
|
d91e2328fb | ||
|
|
2a16735495 | ||
|
|
292f25f9ca | ||
|
|
c92e37775a | ||
|
|
f6ed3d1456 | ||
|
|
84686753e8 | ||
|
|
91f01309da | ||
|
|
57a1fc9d33 | ||
|
|
c95a864975 | ||
|
|
7a83db6180 | ||
|
|
a8513da7ff | ||
|
|
53534d3956 | ||
|
|
cc07a0e295 | ||
|
|
e7bc62500b | ||
|
|
c8fb9ef3a5 | ||
|
|
eb5e6214bc | ||
|
|
568d6ee10e | ||
|
|
6aef1af76e | ||
|
|
a54852e129 | ||
|
|
668118def1 | ||
|
|
73e6b160f8 | ||
|
|
6fec141de6 | ||
|
|
31cde6c555 | ||
|
|
b1a980f344 | ||
|
|
00d9fbd220 | ||
|
|
4f4c9679bf | ||
|
|
3dab71729d | ||
|
|
2f6f758670 | ||
|
|
090c8981dd | ||
|
|
fbb572948d | ||
|
|
a652b513d3 | ||
|
|
ccfeaeb22d | ||
|
|
4c12799a95 | ||
|
|
0f8d42c577 | ||
|
|
03c7578713 | ||
|
|
de6797c560 | ||
|
|
46ae08ecb7 | ||
|
|
2028cc29b7 | ||
|
|
f6360e0bf3 | ||
|
|
9abda1bc59 | ||
|
|
2a94cc76a6 | ||
|
|
150b315a7b | ||
|
|
a07174c191 | ||
|
|
fb839ae6ca | ||
|
|
bdc426a774 | ||
|
|
32fff3798c | ||
|
|
2b02c6635d | ||
|
|
771baa66ee | ||
|
|
a82029b0cf | ||
|
|
0c2a901af4 | ||
|
|
bd18f4b8ef | ||
|
|
bf7b79f2f0 | ||
|
|
45e8598d32 | ||
|
|
8391d480c9 | ||
|
|
d17f853a5f | ||
|
|
ef5a41057f | ||
|
|
c115c9e048 | ||
|
|
6941315432 | ||
|
|
8b071cc665 | ||
|
|
959f6c538a | ||
|
|
217b3b59c0 | ||
|
|
ec916a3197 | ||
|
|
22eb72e0f9 | ||
|
|
07ba64c666 | ||
|
|
f22bc59fe3 | ||
|
|
0ce8666cc0 | ||
|
|
5427a9e422 | ||
|
|
5e9f5efbe3 | ||
|
|
a7a0017aa8 | ||
|
|
9078b17a41 | ||
|
|
14a3694a9a | ||
|
|
b9b4db3df5 | ||
|
|
bc1d7edc58 | ||
|
|
5a6f60a954 | ||
|
|
a61cc2cb24 | ||
|
|
31933c8a60 | ||
|
|
78bccd032d | ||
|
|
ae21db77ec | ||
|
|
ac7503d95f | ||
|
|
69c4b17a9b | ||
|
|
a7165b0f73 | ||
|
|
cc0fca35ec | ||
|
|
dae0d5321f | ||
|
|
34415db7ed | ||
|
|
28e46e0e7c | ||
|
|
7379423325 | ||
|
|
74a3c74514 | ||
|
|
3d6d131889 | ||
|
|
b0569d873a | ||
|
|
d9433699db | ||
|
|
92234857f7 | ||
|
|
8efa361728 | ||
|
|
1be3eacad5 | ||
|
|
34d6b0a601 | ||
|
|
eb432a49ed | ||
|
|
04811c00cb | ||
|
|
06093d4f79 | ||
|
|
2055a60bcb | ||
|
|
cc892744bc | ||
|
|
577ee16108 | ||
|
|
392a8ac7ea | ||
|
|
226920064b | ||
|
|
19865b865f | ||
|
|
e3f812c2fe | ||
|
|
c9f79dee66 | ||
|
|
c659788022 | ||
|
|
aeb987ceb1 | ||
|
|
b478982484 | ||
|
|
fe71ee57b3 | ||
|
|
fba3d21a35 | ||
|
|
455576300c | ||
|
|
821968903c | ||
|
|
452fa53c0d | ||
|
|
95fe1e818f | ||
|
|
a61042bca0 | ||
|
|
b4abfae4de | ||
|
|
c02c8646a6 | ||
|
|
3ff2ca8d41 | ||
|
|
415840088e | ||
|
|
c4f6c89b65 | ||
|
|
539b41f421 | ||
|
|
b2ff326ced | ||
|
|
8b95d16220 | ||
|
|
a478822b8e | ||
|
|
23aa69f56f | ||
|
|
b36f3db9de | ||
|
|
e93f086485 | ||
|
|
930e9ee55c | ||
|
|
38961ba10e | ||
|
|
93b5b7474b | ||
|
|
f862ddc9ff | ||
|
|
b59032304c | ||
|
|
3ba4d535e3 | ||
|
|
5b37e9aea4 | ||
|
|
1820389a05 | ||
|
|
35e3a89385 | ||
|
|
5f890e85e7 | ||
|
|
10bc7f7042 | ||
|
|
a65fd9dee8 | ||
|
|
1bb4c76deb | ||
|
|
aab44f9fc8 | ||
|
|
0a848e7578 | ||
|
|
90bce60b85 | ||
|
|
c22d51ee41 | ||
|
|
a458e684bc | ||
|
|
87b4662993 | ||
|
|
3a100339b9 | ||
|
|
47eb3c8888 | ||
|
|
4672a6fac3 | ||
|
|
82743704e4 | ||
|
|
cc2d064ab4 | ||
|
|
27214f8657 | ||
|
|
28de614dfb | ||
|
|
850183c269 | ||
|
|
2a5ef6d3f5 | ||
|
|
1d231c6cc3 | ||
|
|
20c71acb3b | ||
|
|
52ad7c6e9c | ||
|
|
5aaaffe4d1 | ||
|
|
5354ba3662 | ||
|
|
2daf13c4c8 | ||
|
|
16a90f3d3a | ||
|
|
8a0ff15242 | ||
|
|
8c993dfd35 | ||
|
|
2a6fb1e456 | ||
|
|
9e6cd36af4 | ||
|
|
f25f992a30 | ||
|
|
841d7ef2f2 | ||
|
|
a7a49be850 | ||
|
|
d5eab7da3b | ||
|
|
9b10241561 | ||
|
|
76448ab555 | ||
|
|
9584af5cb4 | ||
|
|
6fabddcb0b | ||
|
|
5efeabb0c6 | ||
|
|
806f402bba | ||
|
|
5432087d96 | ||
|
|
02cb14c7b8 | ||
|
|
9bdb45be7c | ||
|
|
514c0562e0 | ||
|
|
371275ec34 | ||
|
|
ec24a3c361 | ||
|
|
d89e797bfc | ||
|
|
55e469c7fe | ||
|
|
fb99ceacc7 | ||
|
|
daf10907e4 | ||
|
|
b3b2868f55 | ||
|
|
25b00abca1 | ||
|
|
8d0767352b | ||
|
|
918a253851 | ||
|
|
63711067e6 | ||
|
|
7158b38897 | ||
|
|
7f317b9093 | ||
|
|
7c4309ea24 | ||
|
|
5013290486 | ||
|
|
8cf3e9a620 | ||
|
|
060699c3b8 | ||
|
|
2ca6c631ac | ||
|
|
967e25878f | ||
|
|
182683814b | ||
|
|
99cbfa1567 | ||
|
|
3f8c8d70ad | ||
|
|
9c567fad92 | ||
|
|
33f58d583d | ||
|
|
0abb3a6843 | ||
|
|
3663951d11 | ||
|
|
1e169685f4 | ||
|
|
f38a3e7585 | ||
|
|
b8da5d45ce | ||
|
|
659df6e220 | ||
|
|
d601768016 | ||
|
|
16ddc6a83b | ||
|
|
340dc9cadb | ||
|
|
55fced3942 | ||
|
|
7bbf49fd65 | ||
|
|
eea6c2d02c | ||
|
|
70eaa450db | ||
|
|
55796a118d | ||
|
|
9a22d1a690 | ||
|
|
c9d21d53e6 | ||
|
|
e1015c2759 | ||
|
|
d7fa47d732 | ||
|
|
3d6e01a58f | ||
|
|
f9713e8733 | ||
|
|
0e44829720 | ||
|
|
f0ece82111 | ||
|
|
9618cb5643 | ||
|
|
9c02ab789d | ||
|
|
11bfc807d7 | ||
|
|
c2a6ca8d3a | ||
|
|
7b1cf2c495 | ||
|
|
da1f3d61be | ||
|
|
dc3cd62125 | ||
|
|
bc404d4fc1 | ||
|
|
a4a0c0e2cc | ||
|
|
c7abfe67b5 | ||
|
|
4e3476a669 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -83,6 +83,8 @@ temp/
|
||||
*.log
|
||||
*.bak
|
||||
.cache/
|
||||
.dev/
|
||||
.serena/
|
||||
|
||||
# ===================
|
||||
# 构建产物
|
||||
@@ -127,3 +129,4 @@ deploy/docker-compose.override.yml
|
||||
.gocache/
|
||||
vite.config.js
|
||||
docs/*
|
||||
.serena/
|
||||
164
PR_DESCRIPTION.md
Normal file
164
PR_DESCRIPTION.md
Normal file
@@ -0,0 +1,164 @@
|
||||
## 概述
|
||||
|
||||
全面增强运维监控系统(Ops)的错误日志管理和告警静默功能,优化前端 UI 组件代码质量和用户体验。本次更新重构了核心服务层和数据访问层,提升系统可维护性和运维效率。
|
||||
|
||||
## 主要改动
|
||||
|
||||
### 1. 错误日志查询优化
|
||||
|
||||
**功能特性:**
|
||||
- 新增 GetErrorLogByID 接口,支持按 ID 精确查询错误详情
|
||||
- 优化错误日志过滤逻辑,支持多维度筛选(平台、阶段、来源、所有者等)
|
||||
- 改进查询参数处理,简化代码结构
|
||||
- 增强错误分类和标准化处理
|
||||
- 支持错误解决状态追踪(resolved 字段)
|
||||
|
||||
**技术实现:**
|
||||
- `ops_handler.go` - 新增单条错误日志查询接口
|
||||
- `ops_repo.go` - 优化数据查询和过滤条件构建
|
||||
- `ops_models.go` - 扩展错误日志数据模型
|
||||
- 前端 API 接口同步更新
|
||||
|
||||
### 2. 告警静默功能
|
||||
|
||||
**功能特性:**
|
||||
- 支持按规则、平台、分组、区域等维度静默告警
|
||||
- 可设置静默时长和原因说明
|
||||
- 静默记录可追溯,记录创建人和创建时间
|
||||
- 自动过期机制,避免永久静默
|
||||
|
||||
**技术实现:**
|
||||
- `037_ops_alert_silences.sql` - 新增告警静默表
|
||||
- `ops_alerts.go` - 告警静默逻辑实现
|
||||
- `ops_alerts_handler.go` - 告警静默 API 接口
|
||||
- `OpsAlertEventsCard.vue` - 前端告警静默操作界面
|
||||
|
||||
**数据库结构:**
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| rule_id | BIGINT | 告警规则 ID |
|
||||
| platform | VARCHAR(64) | 平台标识 |
|
||||
| group_id | BIGINT | 分组 ID(可选) |
|
||||
| region | VARCHAR(64) | 区域(可选) |
|
||||
| until | TIMESTAMPTZ | 静默截止时间 |
|
||||
| reason | TEXT | 静默原因 |
|
||||
| created_by | BIGINT | 创建人 ID |
|
||||
|
||||
### 3. 错误分类标准化
|
||||
|
||||
**功能特性:**
|
||||
- 统一错误阶段分类(request|auth|routing|upstream|network|internal)
|
||||
- 规范错误归属分类(client|provider|platform)
|
||||
- 标准化错误来源分类(client_request|upstream_http|gateway)
|
||||
- 自动迁移历史数据到新分类体系
|
||||
|
||||
**技术实现:**
|
||||
- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - 分类标准化迁移
|
||||
- 自动映射历史遗留分类到新标准
|
||||
- 自动解决已恢复的上游错误(客户端状态码 < 400)
|
||||
|
||||
### 4. Gateway 服务集成
|
||||
|
||||
**功能特性:**
|
||||
- 完善各 Gateway 服务的 Ops 集成
|
||||
- 统一错误日志记录接口
|
||||
- 增强上游错误追踪能力
|
||||
|
||||
**涉及服务:**
|
||||
- `antigravity_gateway_service.go` - Antigravity 网关集成
|
||||
- `gateway_service.go` - 通用网关集成
|
||||
- `gemini_messages_compat_service.go` - Gemini 兼容层集成
|
||||
- `openai_gateway_service.go` - OpenAI 网关集成
|
||||
|
||||
### 5. 前端 UI 优化
|
||||
|
||||
**代码重构:**
|
||||
- 大幅简化错误详情模态框代码(从 828 行优化到 450 行)
|
||||
- 优化错误日志表格组件,提升可读性
|
||||
- 清理未使用的 i18n 翻译,减少冗余
|
||||
- 统一组件代码风格和格式
|
||||
- 优化骨架屏组件,更好匹配实际看板布局
|
||||
|
||||
**布局改进:**
|
||||
- 修复模态框内容溢出和滚动问题
|
||||
- 优化表格布局,使用 flex 布局确保正确显示
|
||||
- 改进看板头部布局和交互
|
||||
- 提升响应式体验
|
||||
- 骨架屏支持全屏模式适配
|
||||
|
||||
**交互优化:**
|
||||
- 优化告警事件卡片功能和展示
|
||||
- 改进错误详情展示逻辑
|
||||
- 增强请求详情模态框
|
||||
- 完善运行时设置卡片
|
||||
- 改进加载动画效果
|
||||
|
||||
### 6. 国际化完善
|
||||
|
||||
**文案补充:**
|
||||
- 补充错误日志相关的英文翻译
|
||||
- 添加告警静默功能的中英文文案
|
||||
- 完善提示文本和错误信息
|
||||
- 统一术语翻译标准
|
||||
|
||||
## 文件变更
|
||||
|
||||
**后端(26 个文件):**
|
||||
- `backend/internal/handler/admin/ops_alerts_handler.go` - 告警接口增强
|
||||
- `backend/internal/handler/admin/ops_handler.go` - 错误日志接口优化
|
||||
- `backend/internal/handler/ops_error_logger.go` - 错误记录器增强
|
||||
- `backend/internal/repository/ops_repo.go` - 数据访问层重构
|
||||
- `backend/internal/repository/ops_repo_alerts.go` - 告警数据访问增强
|
||||
- `backend/internal/service/ops_*.go` - 核心服务层重构(10 个文件)
|
||||
- `backend/internal/service/*_gateway_service.go` - Gateway 集成(4 个文件)
|
||||
- `backend/internal/server/routes/admin.go` - 路由配置更新
|
||||
- `backend/migrations/*.sql` - 数据库迁移(2 个文件)
|
||||
- 测试文件更新(5 个文件)
|
||||
|
||||
**前端(13 个文件):**
|
||||
- `frontend/src/views/admin/ops/OpsDashboard.vue` - 看板主页优化
|
||||
- `frontend/src/views/admin/ops/components/*.vue` - 组件重构(10 个文件)
|
||||
- `frontend/src/api/admin/ops.ts` - API 接口扩展
|
||||
- `frontend/src/i18n/locales/*.ts` - 国际化文本(2 个文件)
|
||||
|
||||
## 代码统计
|
||||
|
||||
- 44 个文件修改
|
||||
- 3733 行新增
|
||||
- 995 行删除
|
||||
- 净增加 2738 行
|
||||
|
||||
## 核心改进
|
||||
|
||||
**可维护性提升:**
|
||||
- 重构核心服务层,职责更清晰
|
||||
- 简化前端组件代码,降低复杂度
|
||||
- 统一代码风格和命名规范
|
||||
- 清理冗余代码和未使用的翻译
|
||||
- 标准化错误分类体系
|
||||
|
||||
**功能完善:**
|
||||
- 告警静默功能,减少告警噪音
|
||||
- 错误日志查询优化,提升运维效率
|
||||
- Gateway 服务集成完善,统一监控能力
|
||||
- 错误解决状态追踪,便于问题管理
|
||||
|
||||
**用户体验优化:**
|
||||
- 修复多个 UI 布局问题
|
||||
- 优化交互流程
|
||||
- 完善国际化支持
|
||||
- 提升响应式体验
|
||||
- 改进加载状态展示
|
||||
|
||||
## 测试验证
|
||||
|
||||
- ✅ 错误日志查询和过滤功能
|
||||
- ✅ 告警静默创建和自动过期
|
||||
- ✅ 错误分类标准化迁移
|
||||
- ✅ Gateway 服务错误日志记录
|
||||
- ✅ 前端组件布局和交互
|
||||
- ✅ 骨架屏全屏模式适配
|
||||
- ✅ 国际化文本完整性
|
||||
- ✅ API 接口功能正确性
|
||||
- ✅ 数据库迁移执行成功
|
||||
@@ -18,7 +18,7 @@ English | [中文](README_CN.md)
|
||||
|
||||
## Demo
|
||||
|
||||
Try Sub2API online: **https://v2.pincc.ai/**
|
||||
Try Sub2API online: **https://demo.sub2api.org/**
|
||||
|
||||
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
||||
|
||||
|
||||
@@ -57,6 +57,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
||||
|
||||
---
|
||||
|
||||
## OpenAI Responses 兼容注意事项
|
||||
|
||||
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
||||
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
||||
|
||||
---
|
||||
|
||||
## 部署方式
|
||||
|
||||
### 方式一:脚本安装(推荐)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -44,7 +45,25 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// initLogger configures the default slog handler based on gin.Mode().
|
||||
// In non-release mode, Debug level logs are enabled.
|
||||
func initLogger() {
|
||||
var level slog.Level
|
||||
if gin.Mode() == gin.ReleaseMode {
|
||||
level = slog.LevelInfo
|
||||
} else {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
})
|
||||
slog.SetDefault(slog.New(handler))
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Initialize slog logger based on gin mode
|
||||
initLogger()
|
||||
|
||||
// Parse command line flags
|
||||
setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
|
||||
showVersion := flag.Bool("version", false, "Show version information")
|
||||
|
||||
@@ -70,6 +70,8 @@ func provideCleanup(
|
||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
accountExpiry *service.AccountExpiryService,
|
||||
subscriptionExpiry *service.SubscriptionExpiryService,
|
||||
usageCleanup *service.UsageCleanupService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
billingCache *service.BillingCacheService,
|
||||
@@ -123,6 +125,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"UsageCleanupService", func() error {
|
||||
if usageCleanup != nil {
|
||||
usageCleanup.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"TokenRefreshService", func() error {
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
@@ -131,6 +139,10 @@ func provideCleanup(
|
||||
accountExpiry.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"SubscriptionExpiryService", func() error {
|
||||
subscriptionExpiry.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
pricing.Stop()
|
||||
return nil
|
||||
|
||||
@@ -63,11 +63,16 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
|
||||
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
totpCache := repository.NewTotpCache(redisClient)
|
||||
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService)
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
||||
@@ -76,15 +81,21 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
|
||||
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
|
||||
timingWheelService := service.ProvideTimingWheelService()
|
||||
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
|
||||
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
|
||||
timingWheelService, err := service.ProvideTimingWheelService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
|
||||
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
||||
accountRepository := repository.NewAccountRepository(client, db)
|
||||
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
|
||||
adminUserHandler := admin.NewUserHandler(adminService)
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
@@ -98,25 +109,25 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
|
||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||
usageCache := service.NewUsageCache()
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
|
||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||
identityCache := repository.NewIdentityCache(redisClient)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
@@ -125,17 +136,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
promoHandler := admin.NewPromoHandler(promoService)
|
||||
opsRepository := repository.NewOpsRepository(db)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityCache := repository.NewIdentityCache(redisClient)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
|
||||
@@ -146,7 +160,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
|
||||
systemHandler := handler.ProvideSystemHandler(updateService)
|
||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
|
||||
usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
|
||||
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService)
|
||||
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
|
||||
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||
@@ -155,7 +171,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
||||
totpHandler := handler.NewTotpHandler(totpService)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
|
||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||
@@ -166,9 +183,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -201,6 +219,8 @@ func provideCleanup(
|
||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
accountExpiry *service.AccountExpiryService,
|
||||
subscriptionExpiry *service.SubscriptionExpiryService,
|
||||
usageCleanup *service.UsageCleanupService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
billingCache *service.BillingCacheService,
|
||||
@@ -253,6 +273,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"UsageCleanupService", func() error {
|
||||
if usageCleanup != nil {
|
||||
usageCleanup.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"TokenRefreshService", func() error {
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
@@ -261,6 +287,10 @@ func provideCleanup(
|
||||
accountExpiry.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"SubscriptionExpiryService", func() error {
|
||||
subscriptionExpiry.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
pricing.Stop()
|
||||
return nil
|
||||
|
||||
@@ -43,6 +43,8 @@ type Account struct {
|
||||
Concurrency int `json:"concurrency,omitempty"`
|
||||
// Priority holds the value of the "priority" field.
|
||||
Priority int `json:"priority,omitempty"`
|
||||
// RateMultiplier holds the value of the "rate_multiplier" field.
|
||||
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
|
||||
// Status holds the value of the "status" field.
|
||||
Status string `json:"status,omitempty"`
|
||||
// ErrorMessage holds the value of the "error_message" field.
|
||||
@@ -135,6 +137,8 @@ func (*Account) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new([]byte)
|
||||
case account.FieldAutoPauseOnExpired, account.FieldSchedulable:
|
||||
values[i] = new(sql.NullBool)
|
||||
case account.FieldRateMultiplier:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
|
||||
@@ -241,6 +245,12 @@ func (_m *Account) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Priority = int(value.Int64)
|
||||
}
|
||||
case account.FieldRateMultiplier:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field rate_multiplier", values[i])
|
||||
} else if value.Valid {
|
||||
_m.RateMultiplier = value.Float64
|
||||
}
|
||||
case account.FieldStatus:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field status", values[i])
|
||||
@@ -420,6 +430,9 @@ func (_m *Account) String() string {
|
||||
builder.WriteString("priority=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("rate_multiplier=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("status=")
|
||||
builder.WriteString(_m.Status)
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -39,6 +39,8 @@ const (
|
||||
FieldConcurrency = "concurrency"
|
||||
// FieldPriority holds the string denoting the priority field in the database.
|
||||
FieldPriority = "priority"
|
||||
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
|
||||
FieldRateMultiplier = "rate_multiplier"
|
||||
// FieldStatus holds the string denoting the status field in the database.
|
||||
FieldStatus = "status"
|
||||
// FieldErrorMessage holds the string denoting the error_message field in the database.
|
||||
@@ -116,6 +118,7 @@ var Columns = []string{
|
||||
FieldProxyID,
|
||||
FieldConcurrency,
|
||||
FieldPriority,
|
||||
FieldRateMultiplier,
|
||||
FieldStatus,
|
||||
FieldErrorMessage,
|
||||
FieldLastUsedAt,
|
||||
@@ -174,6 +177,8 @@ var (
|
||||
DefaultConcurrency int
|
||||
// DefaultPriority holds the default value on creation for the "priority" field.
|
||||
DefaultPriority int
|
||||
// DefaultRateMultiplier holds the default value on creation for the "rate_multiplier" field.
|
||||
DefaultRateMultiplier float64
|
||||
// DefaultStatus holds the default value on creation for the "status" field.
|
||||
DefaultStatus string
|
||||
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
@@ -244,6 +249,11 @@ func ByPriority(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldPriority, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByRateMultiplier orders the results by the rate_multiplier field.
|
||||
func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByStatus orders the results by the status field.
|
||||
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldStatus, opts...).ToFunc()
|
||||
|
||||
@@ -105,6 +105,11 @@ func Priority(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldPriority, v))
|
||||
}
|
||||
|
||||
// RateMultiplier applies equality check predicate on the "rate_multiplier" field. It's identical to RateMultiplierEQ.
|
||||
func RateMultiplier(v float64) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
|
||||
}
|
||||
|
||||
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
|
||||
func Status(v string) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldStatus, v))
|
||||
@@ -675,6 +680,46 @@ func PriorityLTE(v int) predicate.Account {
|
||||
return predicate.Account(sql.FieldLTE(FieldPriority, v))
|
||||
}
|
||||
|
||||
// RateMultiplierEQ applies the EQ predicate on the "rate_multiplier" field.
|
||||
func RateMultiplierEQ(v float64) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
|
||||
}
|
||||
|
||||
// RateMultiplierNEQ applies the NEQ predicate on the "rate_multiplier" field.
|
||||
func RateMultiplierNEQ(v float64) predicate.Account {
|
||||
return predicate.Account(sql.FieldNEQ(FieldRateMultiplier, v))
|
||||
}
|
||||
|
||||
// RateMultiplierIn applies the In predicate on the "rate_multiplier" field.
|
||||
func RateMultiplierIn(vs ...float64) predicate.Account {
|
||||
return predicate.Account(sql.FieldIn(FieldRateMultiplier, vs...))
|
||||
}
|
||||
|
||||
// RateMultiplierNotIn applies the NotIn predicate on the "rate_multiplier" field.
|
||||
func RateMultiplierNotIn(vs ...float64) predicate.Account {
|
||||
return predicate.Account(sql.FieldNotIn(FieldRateMultiplier, vs...))
|
||||
}
|
||||
|
||||
// RateMultiplierGT applies the GT predicate on the "rate_multiplier" field.
|
||||
func RateMultiplierGT(v float64) predicate.Account {
|
||||
return predicate.Account(sql.FieldGT(FieldRateMultiplier, v))
|
||||
}
|
||||
|
||||
// RateMultiplierGTE applies the GTE predicate on the "rate_multiplier" field.
|
||||
func RateMultiplierGTE(v float64) predicate.Account {
|
||||
return predicate.Account(sql.FieldGTE(FieldRateMultiplier, v))
|
||||
}
|
||||
|
||||
// RateMultiplierLT applies the LT predicate on the "rate_multiplier" field.
|
||||
func RateMultiplierLT(v float64) predicate.Account {
|
||||
return predicate.Account(sql.FieldLT(FieldRateMultiplier, v))
|
||||
}
|
||||
|
||||
// RateMultiplierLTE applies the LTE predicate on the "rate_multiplier" field.
|
||||
func RateMultiplierLTE(v float64) predicate.Account {
|
||||
return predicate.Account(sql.FieldLTE(FieldRateMultiplier, v))
|
||||
}
|
||||
|
||||
// StatusEQ applies the EQ predicate on the "status" field.
|
||||
func StatusEQ(v string) predicate.Account {
|
||||
return predicate.Account(sql.FieldEQ(FieldStatus, v))
|
||||
|
||||
@@ -153,6 +153,20 @@ func (_c *AccountCreate) SetNillablePriority(v *int) *AccountCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||
func (_c *AccountCreate) SetRateMultiplier(v float64) *AccountCreate {
|
||||
_c.mutation.SetRateMultiplier(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
|
||||
func (_c *AccountCreate) SetNillableRateMultiplier(v *float64) *AccountCreate {
|
||||
if v != nil {
|
||||
_c.SetRateMultiplier(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetStatus sets the "status" field.
|
||||
func (_c *AccountCreate) SetStatus(v string) *AccountCreate {
|
||||
_c.mutation.SetStatus(v)
|
||||
@@ -429,6 +443,10 @@ func (_c *AccountCreate) defaults() error {
|
||||
v := account.DefaultPriority
|
||||
_c.mutation.SetPriority(v)
|
||||
}
|
||||
if _, ok := _c.mutation.RateMultiplier(); !ok {
|
||||
v := account.DefaultRateMultiplier
|
||||
_c.mutation.SetRateMultiplier(v)
|
||||
}
|
||||
if _, ok := _c.mutation.Status(); !ok {
|
||||
v := account.DefaultStatus
|
||||
_c.mutation.SetStatus(v)
|
||||
@@ -488,6 +506,9 @@ func (_c *AccountCreate) check() error {
|
||||
if _, ok := _c.mutation.Priority(); !ok {
|
||||
return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "Account.priority"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.RateMultiplier(); !ok {
|
||||
return &ValidationError{Name: "rate_multiplier", err: errors.New(`ent: missing required field "Account.rate_multiplier"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.Status(); !ok {
|
||||
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Account.status"`)}
|
||||
}
|
||||
@@ -578,6 +599,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(account.FieldPriority, field.TypeInt, value)
|
||||
_node.Priority = value
|
||||
}
|
||||
if value, ok := _c.mutation.RateMultiplier(); ok {
|
||||
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||
_node.RateMultiplier = value
|
||||
}
|
||||
if value, ok := _c.mutation.Status(); ok {
|
||||
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
||||
_node.Status = value
|
||||
@@ -893,6 +918,24 @@ func (u *AccountUpsert) AddPriority(v int) *AccountUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||
func (u *AccountUpsert) SetRateMultiplier(v float64) *AccountUpsert {
|
||||
u.Set(account.FieldRateMultiplier, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
|
||||
func (u *AccountUpsert) UpdateRateMultiplier() *AccountUpsert {
|
||||
u.SetExcluded(account.FieldRateMultiplier)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddRateMultiplier adds v to the "rate_multiplier" field.
|
||||
func (u *AccountUpsert) AddRateMultiplier(v float64) *AccountUpsert {
|
||||
u.Add(account.FieldRateMultiplier, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetStatus sets the "status" field.
|
||||
func (u *AccountUpsert) SetStatus(v string) *AccountUpsert {
|
||||
u.Set(account.FieldStatus, v)
|
||||
@@ -1325,6 +1368,27 @@ func (u *AccountUpsertOne) UpdatePriority() *AccountUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||
func (u *AccountUpsertOne) SetRateMultiplier(v float64) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddRateMultiplier adds v to the "rate_multiplier" field.
|
||||
func (u *AccountUpsertOne) AddRateMultiplier(v float64) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.AddRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
|
||||
func (u *AccountUpsertOne) UpdateRateMultiplier() *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateRateMultiplier()
|
||||
})
|
||||
}
|
||||
|
||||
// SetStatus sets the "status" field.
|
||||
func (u *AccountUpsertOne) SetStatus(v string) *AccountUpsertOne {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
@@ -1956,6 +2020,27 @@ func (u *AccountUpsertBulk) UpdatePriority() *AccountUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||
func (u *AccountUpsertBulk) SetRateMultiplier(v float64) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.SetRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddRateMultiplier adds v to the "rate_multiplier" field.
|
||||
func (u *AccountUpsertBulk) AddRateMultiplier(v float64) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.AddRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
|
||||
func (u *AccountUpsertBulk) UpdateRateMultiplier() *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
s.UpdateRateMultiplier()
|
||||
})
|
||||
}
|
||||
|
||||
// SetStatus sets the "status" field.
|
||||
func (u *AccountUpsertBulk) SetStatus(v string) *AccountUpsertBulk {
|
||||
return u.Update(func(s *AccountUpsert) {
|
||||
|
||||
@@ -193,6 +193,27 @@ func (_u *AccountUpdate) AddPriority(v int) *AccountUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||
func (_u *AccountUpdate) SetRateMultiplier(v float64) *AccountUpdate {
|
||||
_u.mutation.ResetRateMultiplier()
|
||||
_u.mutation.SetRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
|
||||
func (_u *AccountUpdate) SetNillableRateMultiplier(v *float64) *AccountUpdate {
|
||||
if v != nil {
|
||||
_u.SetRateMultiplier(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddRateMultiplier adds value to the "rate_multiplier" field.
|
||||
func (_u *AccountUpdate) AddRateMultiplier(v float64) *AccountUpdate {
|
||||
_u.mutation.AddRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetStatus sets the "status" field.
|
||||
func (_u *AccountUpdate) SetStatus(v string) *AccountUpdate {
|
||||
_u.mutation.SetStatus(v)
|
||||
@@ -629,6 +650,12 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedPriority(); ok {
|
||||
_spec.AddField(account.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RateMultiplier(); ok {
|
||||
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
|
||||
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
@@ -1005,6 +1032,27 @@ func (_u *AccountUpdateOne) AddPriority(v int) *AccountUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRateMultiplier sets the "rate_multiplier" field.
|
||||
func (_u *AccountUpdateOne) SetRateMultiplier(v float64) *AccountUpdateOne {
|
||||
_u.mutation.ResetRateMultiplier()
|
||||
_u.mutation.SetRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
|
||||
func (_u *AccountUpdateOne) SetNillableRateMultiplier(v *float64) *AccountUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetRateMultiplier(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddRateMultiplier adds value to the "rate_multiplier" field.
|
||||
func (_u *AccountUpdateOne) AddRateMultiplier(v float64) *AccountUpdateOne {
|
||||
_u.mutation.AddRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetStatus sets the "status" field.
|
||||
func (_u *AccountUpdateOne) SetStatus(v string) *AccountUpdateOne {
|
||||
_u.mutation.SetStatus(v)
|
||||
@@ -1471,6 +1519,12 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
|
||||
if value, ok := _u.mutation.AddedPriority(); ok {
|
||||
_spec.AddField(account.FieldPriority, field.TypeInt, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RateMultiplier(); ok {
|
||||
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
|
||||
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(account.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
@@ -57,6 +58,8 @@ type Client struct {
|
||||
RedeemCode *RedeemCodeClient
|
||||
// Setting is the client for interacting with the Setting builders.
|
||||
Setting *SettingClient
|
||||
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
|
||||
UsageCleanupTask *UsageCleanupTaskClient
|
||||
// UsageLog is the client for interacting with the UsageLog builders.
|
||||
UsageLog *UsageLogClient
|
||||
// User is the client for interacting with the User builders.
|
||||
@@ -89,6 +92,7 @@ func (c *Client) init() {
|
||||
c.Proxy = NewProxyClient(c.config)
|
||||
c.RedeemCode = NewRedeemCodeClient(c.config)
|
||||
c.Setting = NewSettingClient(c.config)
|
||||
c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config)
|
||||
c.UsageLog = NewUsageLogClient(c.config)
|
||||
c.User = NewUserClient(c.config)
|
||||
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
|
||||
@@ -196,6 +200,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
|
||||
Proxy: NewProxyClient(cfg),
|
||||
RedeemCode: NewRedeemCodeClient(cfg),
|
||||
Setting: NewSettingClient(cfg),
|
||||
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
|
||||
UsageLog: NewUsageLogClient(cfg),
|
||||
User: NewUserClient(cfg),
|
||||
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
|
||||
@@ -230,6 +235,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
|
||||
Proxy: NewProxyClient(cfg),
|
||||
RedeemCode: NewRedeemCodeClient(cfg),
|
||||
Setting: NewSettingClient(cfg),
|
||||
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
|
||||
UsageLog: NewUsageLogClient(cfg),
|
||||
User: NewUserClient(cfg),
|
||||
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
|
||||
@@ -266,8 +272,9 @@ func (c *Client) Close() error {
|
||||
func (c *Client) Use(hooks ...Hook) {
|
||||
for _, n := range []interface{ Use(...Hook) }{
|
||||
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
|
||||
c.Proxy, c.RedeemCode, c.Setting, c.UsageLog, c.User, c.UserAllowedGroup,
|
||||
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
|
||||
c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
|
||||
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
|
||||
c.UserSubscription,
|
||||
} {
|
||||
n.Use(hooks...)
|
||||
}
|
||||
@@ -278,8 +285,9 @@ func (c *Client) Use(hooks ...Hook) {
|
||||
func (c *Client) Intercept(interceptors ...Interceptor) {
|
||||
for _, n := range []interface{ Intercept(...Interceptor) }{
|
||||
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
|
||||
c.Proxy, c.RedeemCode, c.Setting, c.UsageLog, c.User, c.UserAllowedGroup,
|
||||
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
|
||||
c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
|
||||
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
|
||||
c.UserSubscription,
|
||||
} {
|
||||
n.Intercept(interceptors...)
|
||||
}
|
||||
@@ -306,6 +314,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
|
||||
return c.RedeemCode.mutate(ctx, m)
|
||||
case *SettingMutation:
|
||||
return c.Setting.mutate(ctx, m)
|
||||
case *UsageCleanupTaskMutation:
|
||||
return c.UsageCleanupTask.mutate(ctx, m)
|
||||
case *UsageLogMutation:
|
||||
return c.UsageLog.mutate(ctx, m)
|
||||
case *UserMutation:
|
||||
@@ -1847,6 +1857,139 @@ func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value,
|
||||
}
|
||||
}
|
||||
|
||||
// UsageCleanupTaskClient is a client for the UsageCleanupTask schema.
|
||||
type UsageCleanupTaskClient struct {
|
||||
config
|
||||
}
|
||||
|
||||
// NewUsageCleanupTaskClient returns a client for the UsageCleanupTask from the given config.
|
||||
func NewUsageCleanupTaskClient(c config) *UsageCleanupTaskClient {
|
||||
return &UsageCleanupTaskClient{config: c}
|
||||
}
|
||||
|
||||
// Use adds a list of mutation hooks to the hooks stack.
|
||||
// A call to `Use(f, g, h)` equals to `usagecleanuptask.Hooks(f(g(h())))`.
|
||||
func (c *UsageCleanupTaskClient) Use(hooks ...Hook) {
|
||||
c.hooks.UsageCleanupTask = append(c.hooks.UsageCleanupTask, hooks...)
|
||||
}
|
||||
|
||||
// Intercept adds a list of query interceptors to the interceptors stack.
|
||||
// A call to `Intercept(f, g, h)` equals to `usagecleanuptask.Intercept(f(g(h())))`.
|
||||
func (c *UsageCleanupTaskClient) Intercept(interceptors ...Interceptor) {
|
||||
c.inters.UsageCleanupTask = append(c.inters.UsageCleanupTask, interceptors...)
|
||||
}
|
||||
|
||||
// Create returns a builder for creating a UsageCleanupTask entity.
|
||||
func (c *UsageCleanupTaskClient) Create() *UsageCleanupTaskCreate {
|
||||
mutation := newUsageCleanupTaskMutation(c.config, OpCreate)
|
||||
return &UsageCleanupTaskCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// CreateBulk returns a builder for creating a bulk of UsageCleanupTask entities.
|
||||
func (c *UsageCleanupTaskClient) CreateBulk(builders ...*UsageCleanupTaskCreate) *UsageCleanupTaskCreateBulk {
|
||||
return &UsageCleanupTaskCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
|
||||
// a builder and applies setFunc on it.
|
||||
func (c *UsageCleanupTaskClient) MapCreateBulk(slice any, setFunc func(*UsageCleanupTaskCreate, int)) *UsageCleanupTaskCreateBulk {
|
||||
rv := reflect.ValueOf(slice)
|
||||
if rv.Kind() != reflect.Slice {
|
||||
return &UsageCleanupTaskCreateBulk{err: fmt.Errorf("calling to UsageCleanupTaskClient.MapCreateBulk with wrong type %T, need slice", slice)}
|
||||
}
|
||||
builders := make([]*UsageCleanupTaskCreate, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
builders[i] = c.Create()
|
||||
setFunc(builders[i], i)
|
||||
}
|
||||
return &UsageCleanupTaskCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// Update returns an update builder for UsageCleanupTask.
|
||||
func (c *UsageCleanupTaskClient) Update() *UsageCleanupTaskUpdate {
|
||||
mutation := newUsageCleanupTaskMutation(c.config, OpUpdate)
|
||||
return &UsageCleanupTaskUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOne returns an update builder for the given entity.
|
||||
func (c *UsageCleanupTaskClient) UpdateOne(_m *UsageCleanupTask) *UsageCleanupTaskUpdateOne {
|
||||
mutation := newUsageCleanupTaskMutation(c.config, OpUpdateOne, withUsageCleanupTask(_m))
|
||||
return &UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOneID returns an update builder for the given id.
|
||||
func (c *UsageCleanupTaskClient) UpdateOneID(id int64) *UsageCleanupTaskUpdateOne {
|
||||
mutation := newUsageCleanupTaskMutation(c.config, OpUpdateOne, withUsageCleanupTaskID(id))
|
||||
return &UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// Delete returns a delete builder for UsageCleanupTask.
|
||||
func (c *UsageCleanupTaskClient) Delete() *UsageCleanupTaskDelete {
|
||||
mutation := newUsageCleanupTaskMutation(c.config, OpDelete)
|
||||
return &UsageCleanupTaskDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// DeleteOne returns a builder for deleting the given entity.
|
||||
func (c *UsageCleanupTaskClient) DeleteOne(_m *UsageCleanupTask) *UsageCleanupTaskDeleteOne {
|
||||
return c.DeleteOneID(_m.ID)
|
||||
}
|
||||
|
||||
// DeleteOneID returns a builder for deleting the given entity by its id.
|
||||
func (c *UsageCleanupTaskClient) DeleteOneID(id int64) *UsageCleanupTaskDeleteOne {
|
||||
builder := c.Delete().Where(usagecleanuptask.ID(id))
|
||||
builder.mutation.id = &id
|
||||
builder.mutation.op = OpDeleteOne
|
||||
return &UsageCleanupTaskDeleteOne{builder}
|
||||
}
|
||||
|
||||
// Query returns a query builder for UsageCleanupTask.
|
||||
func (c *UsageCleanupTaskClient) Query() *UsageCleanupTaskQuery {
|
||||
return &UsageCleanupTaskQuery{
|
||||
config: c.config,
|
||||
ctx: &QueryContext{Type: TypeUsageCleanupTask},
|
||||
inters: c.Interceptors(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a UsageCleanupTask entity by its id.
|
||||
func (c *UsageCleanupTaskClient) Get(ctx context.Context, id int64) (*UsageCleanupTask, error) {
|
||||
return c.Query().Where(usagecleanuptask.ID(id)).Only(ctx)
|
||||
}
|
||||
|
||||
// GetX is like Get, but panics if an error occurs.
|
||||
func (c *UsageCleanupTaskClient) GetX(ctx context.Context, id int64) *UsageCleanupTask {
|
||||
obj, err := c.Get(ctx, id)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return obj
|
||||
}
|
||||
|
||||
// Hooks returns the client hooks.
|
||||
func (c *UsageCleanupTaskClient) Hooks() []Hook {
|
||||
return c.hooks.UsageCleanupTask
|
||||
}
|
||||
|
||||
// Interceptors returns the client interceptors.
|
||||
func (c *UsageCleanupTaskClient) Interceptors() []Interceptor {
|
||||
return c.inters.UsageCleanupTask
|
||||
}
|
||||
|
||||
func (c *UsageCleanupTaskClient) mutate(ctx context.Context, m *UsageCleanupTaskMutation) (Value, error) {
|
||||
switch m.Op() {
|
||||
case OpCreate:
|
||||
return (&UsageCleanupTaskCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdate:
|
||||
return (&UsageCleanupTaskUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdateOne:
|
||||
return (&UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpDelete, OpDeleteOne:
|
||||
return (&UsageCleanupTaskDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
|
||||
default:
|
||||
return nil, fmt.Errorf("ent: unknown UsageCleanupTask mutation op: %q", m.Op())
|
||||
}
|
||||
}
|
||||
|
||||
// UsageLogClient is a client for the UsageLog schema.
|
||||
type UsageLogClient struct {
|
||||
config
|
||||
@@ -2974,13 +3117,13 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
|
||||
type (
|
||||
hooks struct {
|
||||
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
|
||||
RedeemCode, Setting, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
|
||||
UserAttributeValue, UserSubscription []ent.Hook
|
||||
RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
|
||||
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
|
||||
}
|
||||
inters struct {
|
||||
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
|
||||
RedeemCode, Setting, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
|
||||
UserAttributeValue, UserSubscription []ent.Interceptor
|
||||
RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
|
||||
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
@@ -96,6 +97,7 @@ func checkColumn(t, c string) error {
|
||||
proxy.Table: proxy.ValidColumn,
|
||||
redeemcode.Table: redeemcode.ValidColumn,
|
||||
setting.Table: setting.ValidColumn,
|
||||
usagecleanuptask.Table: usagecleanuptask.ValidColumn,
|
||||
usagelog.Table: usagelog.ValidColumn,
|
||||
user.Table: user.ValidColumn,
|
||||
userallowedgroup.Table: userallowedgroup.ValidColumn,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package ent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -55,6 +56,10 @@ type Group struct {
|
||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||
// 非 Claude Code 请求降级使用的分组 ID
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
// 模型路由配置:模型模式 -> 优先账号ID列表
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
// 是否启用模型路由配置
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||
Edges GroupEdges `json:"edges"`
|
||||
@@ -161,7 +166,9 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly:
|
||||
case group.FieldModelRouting:
|
||||
values[i] = new([]byte)
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled:
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
@@ -315,6 +322,20 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
_m.FallbackGroupID = new(int64)
|
||||
*_m.FallbackGroupID = value.Int64
|
||||
}
|
||||
case group.FieldModelRouting:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field model_routing", values[i])
|
||||
} else if value != nil && len(*value) > 0 {
|
||||
if err := json.Unmarshal(*value, &_m.ModelRouting); err != nil {
|
||||
return fmt.Errorf("unmarshal field model_routing: %w", err)
|
||||
}
|
||||
}
|
||||
case group.FieldModelRoutingEnabled:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field model_routing_enabled", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ModelRoutingEnabled = value.Bool
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -465,6 +486,12 @@ func (_m *Group) String() string {
|
||||
builder.WriteString("fallback_group_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("model_routing=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("model_routing_enabled=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled))
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -53,6 +53,10 @@ const (
|
||||
FieldClaudeCodeOnly = "claude_code_only"
|
||||
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
||||
FieldFallbackGroupID = "fallback_group_id"
|
||||
// FieldModelRouting holds the string denoting the model_routing field in the database.
|
||||
FieldModelRouting = "model_routing"
|
||||
// FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database.
|
||||
FieldModelRoutingEnabled = "model_routing_enabled"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -147,6 +151,8 @@ var Columns = []string{
|
||||
FieldImagePrice4k,
|
||||
FieldClaudeCodeOnly,
|
||||
FieldFallbackGroupID,
|
||||
FieldModelRouting,
|
||||
FieldModelRoutingEnabled,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -204,6 +210,8 @@ var (
|
||||
DefaultDefaultValidityDays int
|
||||
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
||||
DefaultClaudeCodeOnly bool
|
||||
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
|
||||
DefaultModelRoutingEnabled bool
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the Group queries.
|
||||
@@ -309,6 +317,11 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByModelRoutingEnabled orders the results by the model_routing_enabled field.
|
||||
func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -150,6 +150,11 @@ func FallbackGroupID(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ.
|
||||
func ModelRoutingEnabled(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -1065,6 +1070,26 @@ func FallbackGroupIDNotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
|
||||
}
|
||||
|
||||
// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field.
|
||||
func ModelRoutingIsNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldIsNull(FieldModelRouting))
|
||||
}
|
||||
|
||||
// ModelRoutingNotNil applies the NotNil predicate on the "model_routing" field.
|
||||
func ModelRoutingNotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldModelRouting))
|
||||
}
|
||||
|
||||
// ModelRoutingEnabledEQ applies the EQ predicate on the "model_routing_enabled" field.
|
||||
func ModelRoutingEnabledEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
|
||||
}
|
||||
|
||||
// ModelRoutingEnabledNEQ applies the NEQ predicate on the "model_routing_enabled" field.
|
||||
func ModelRoutingEnabledNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.Group {
|
||||
return predicate.Group(func(s *sql.Selector) {
|
||||
|
||||
@@ -286,6 +286,26 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetModelRouting sets the "model_routing" field.
|
||||
func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate {
|
||||
_c.mutation.SetModelRouting(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||
func (_c *GroupCreate) SetModelRoutingEnabled(v bool) *GroupCreate {
|
||||
_c.mutation.SetModelRoutingEnabled(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableModelRoutingEnabled(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetModelRoutingEnabled(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -455,6 +475,10 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultClaudeCodeOnly
|
||||
_c.mutation.SetClaudeCodeOnly(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
|
||||
v := group.DefaultModelRoutingEnabled
|
||||
_c.mutation.SetModelRoutingEnabled(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -510,6 +534,9 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
|
||||
return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -613,6 +640,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||
_node.FallbackGroupID = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ModelRouting(); ok {
|
||||
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
||||
_node.ModelRouting = value
|
||||
}
|
||||
if value, ok := _c.mutation.ModelRoutingEnabled(); ok {
|
||||
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
|
||||
_node.ModelRoutingEnabled = value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1093,6 +1128,36 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetModelRouting sets the "model_routing" field.
|
||||
func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert {
|
||||
u.Set(group.FieldModelRouting, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateModelRouting() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldModelRouting)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearModelRouting clears the value of the "model_routing" field.
|
||||
func (u *GroupUpsert) ClearModelRouting() *GroupUpsert {
|
||||
u.SetNull(group.FieldModelRouting)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||
func (u *GroupUpsert) SetModelRoutingEnabled(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldModelRoutingEnabled, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateModelRoutingEnabled() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldModelRoutingEnabled)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -1516,6 +1581,41 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetModelRouting sets the "model_routing" field.
|
||||
func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetModelRouting(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateModelRouting() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateModelRouting()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearModelRouting clears the value of the "model_routing" field.
|
||||
func (u *GroupUpsertOne) ClearModelRouting() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearModelRouting()
|
||||
})
|
||||
}
|
||||
|
||||
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||
func (u *GroupUpsertOne) SetModelRoutingEnabled(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetModelRoutingEnabled(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateModelRoutingEnabled() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateModelRoutingEnabled()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -2105,6 +2205,41 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetModelRouting sets the "model_routing" field.
|
||||
func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetModelRouting(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateModelRouting() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateModelRouting()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearModelRouting clears the value of the "model_routing" field.
|
||||
func (u *GroupUpsertBulk) ClearModelRouting() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearModelRouting()
|
||||
})
|
||||
}
|
||||
|
||||
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||
func (u *GroupUpsertBulk) SetModelRoutingEnabled(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetModelRoutingEnabled(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateModelRoutingEnabled() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateModelRoutingEnabled()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -395,6 +395,32 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetModelRouting sets the "model_routing" field.
|
||||
func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate {
|
||||
_u.mutation.SetModelRouting(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearModelRouting clears the value of the "model_routing" field.
|
||||
func (_u *GroupUpdate) ClearModelRouting() *GroupUpdate {
|
||||
_u.mutation.ClearModelRouting()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||
func (_u *GroupUpdate) SetModelRoutingEnabled(v bool) *GroupUpdate {
|
||||
_u.mutation.SetModelRoutingEnabled(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableModelRoutingEnabled(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetModelRoutingEnabled(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -803,6 +829,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.FallbackGroupIDCleared() {
|
||||
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||
}
|
||||
if value, ok := _u.mutation.ModelRouting(); ok {
|
||||
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
||||
}
|
||||
if _u.mutation.ModelRoutingCleared() {
|
||||
_spec.ClearField(group.FieldModelRouting, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
|
||||
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1478,6 +1513,32 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetModelRouting sets the "model_routing" field.
|
||||
func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne {
|
||||
_u.mutation.SetModelRouting(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearModelRouting clears the value of the "model_routing" field.
|
||||
func (_u *GroupUpdateOne) ClearModelRouting() *GroupUpdateOne {
|
||||
_u.mutation.ClearModelRouting()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
|
||||
func (_u *GroupUpdateOne) SetModelRoutingEnabled(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetModelRoutingEnabled(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableModelRoutingEnabled(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetModelRoutingEnabled(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -1916,6 +1977,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if _u.mutation.FallbackGroupIDCleared() {
|
||||
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||
}
|
||||
if value, ok := _u.mutation.ModelRouting(); ok {
|
||||
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
||||
}
|
||||
if _u.mutation.ModelRoutingCleared() {
|
||||
_spec.ClearField(group.FieldModelRouting, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
|
||||
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -117,6 +117,18 @@ func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, err
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m)
|
||||
}
|
||||
|
||||
// The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary
|
||||
// function as UsageCleanupTask mutator.
|
||||
type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskMutation) (ent.Value, error)
|
||||
|
||||
// Mutate calls f(ctx, m).
|
||||
func (f UsageCleanupTaskFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
||||
if mv, ok := m.(*ent.UsageCleanupTaskMutation); ok {
|
||||
return f(ctx, mv)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UsageCleanupTaskMutation", m)
|
||||
}
|
||||
|
||||
// The UsageLogFunc type is an adapter to allow the use of ordinary
|
||||
// function as UsageLog mutator.
|
||||
type UsageLogFunc func(context.Context, *ent.UsageLogMutation) (ent.Value, error)
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
@@ -325,6 +326,33 @@ func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error {
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q)
|
||||
}
|
||||
|
||||
// The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskQuery) (ent.Value, error)
|
||||
|
||||
// Query calls f(ctx, q).
|
||||
func (f UsageCleanupTaskFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
|
||||
if q, ok := q.(*ent.UsageCleanupTaskQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UsageCleanupTaskQuery", q)
|
||||
}
|
||||
|
||||
// The TraverseUsageCleanupTask type is an adapter to allow the use of ordinary function as Traverser.
|
||||
type TraverseUsageCleanupTask func(context.Context, *ent.UsageCleanupTaskQuery) error
|
||||
|
||||
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
|
||||
func (f TraverseUsageCleanupTask) Intercept(next ent.Querier) ent.Querier {
|
||||
return next
|
||||
}
|
||||
|
||||
// Traverse calls f(ctx, q).
|
||||
func (f TraverseUsageCleanupTask) Traverse(ctx context.Context, q ent.Query) error {
|
||||
if q, ok := q.(*ent.UsageCleanupTaskQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.UsageCleanupTaskQuery", q)
|
||||
}
|
||||
|
||||
// The UsageLogFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type UsageLogFunc func(context.Context, *ent.UsageLogQuery) (ent.Value, error)
|
||||
|
||||
@@ -508,6 +536,8 @@ func NewQuery(q ent.Query) (Query, error) {
|
||||
return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil
|
||||
case *ent.SettingQuery:
|
||||
return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil
|
||||
case *ent.UsageCleanupTaskQuery:
|
||||
return &query[*ent.UsageCleanupTaskQuery, predicate.UsageCleanupTask, usagecleanuptask.OrderOption]{typ: ent.TypeUsageCleanupTask, tq: q}, nil
|
||||
case *ent.UsageLogQuery:
|
||||
return &query[*ent.UsageLogQuery, predicate.UsageLog, usagelog.OrderOption]{typ: ent.TypeUsageLog, tq: q}, nil
|
||||
case *ent.UserQuery:
|
||||
|
||||
@@ -79,6 +79,7 @@ var (
|
||||
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "concurrency", Type: field.TypeInt, Default: 3},
|
||||
{Name: "priority", Type: field.TypeInt, Default: 50},
|
||||
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||
{Name: "error_message", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "last_used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
@@ -101,7 +102,7 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "accounts_proxies_proxy",
|
||||
Columns: []*schema.Column{AccountsColumns[24]},
|
||||
Columns: []*schema.Column{AccountsColumns[25]},
|
||||
RefColumns: []*schema.Column{ProxiesColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -120,12 +121,12 @@ var (
|
||||
{
|
||||
Name: "account_status",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[12]},
|
||||
Columns: []*schema.Column{AccountsColumns[13]},
|
||||
},
|
||||
{
|
||||
Name: "account_proxy_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[24]},
|
||||
Columns: []*schema.Column{AccountsColumns[25]},
|
||||
},
|
||||
{
|
||||
Name: "account_priority",
|
||||
@@ -135,27 +136,27 @@ var (
|
||||
{
|
||||
Name: "account_last_used_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[14]},
|
||||
Columns: []*schema.Column{AccountsColumns[15]},
|
||||
},
|
||||
{
|
||||
Name: "account_schedulable",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[17]},
|
||||
Columns: []*schema.Column{AccountsColumns[18]},
|
||||
},
|
||||
{
|
||||
Name: "account_rate_limited_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[18]},
|
||||
Columns: []*schema.Column{AccountsColumns[19]},
|
||||
},
|
||||
{
|
||||
Name: "account_rate_limit_reset_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[19]},
|
||||
Columns: []*schema.Column{AccountsColumns[20]},
|
||||
},
|
||||
{
|
||||
Name: "account_overload_until",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{AccountsColumns[20]},
|
||||
Columns: []*schema.Column{AccountsColumns[21]},
|
||||
},
|
||||
{
|
||||
Name: "account_deleted_at",
|
||||
@@ -225,6 +226,8 @@ var (
|
||||
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
||||
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
||||
{Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
GroupsTable = &schema.Table{
|
||||
@@ -431,6 +434,44 @@ var (
|
||||
Columns: SettingsColumns,
|
||||
PrimaryKey: []*schema.Column{SettingsColumns[0]},
|
||||
}
|
||||
// UsageCleanupTasksColumns holds the columns for the "usage_cleanup_tasks" table.
|
||||
UsageCleanupTasksColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "status", Type: field.TypeString, Size: 20},
|
||||
{Name: "filters", Type: field.TypeJSON},
|
||||
{Name: "created_by", Type: field.TypeInt64},
|
||||
{Name: "deleted_rows", Type: field.TypeInt64, Default: 0},
|
||||
{Name: "error_message", Type: field.TypeString, Nullable: true},
|
||||
{Name: "canceled_by", Type: field.TypeInt64, Nullable: true},
|
||||
{Name: "canceled_at", Type: field.TypeTime, Nullable: true},
|
||||
{Name: "started_at", Type: field.TypeTime, Nullable: true},
|
||||
{Name: "finished_at", Type: field.TypeTime, Nullable: true},
|
||||
}
|
||||
// UsageCleanupTasksTable holds the schema information for the "usage_cleanup_tasks" table.
|
||||
UsageCleanupTasksTable = &schema.Table{
|
||||
Name: "usage_cleanup_tasks",
|
||||
Columns: UsageCleanupTasksColumns,
|
||||
PrimaryKey: []*schema.Column{UsageCleanupTasksColumns[0]},
|
||||
Indexes: []*schema.Index{
|
||||
{
|
||||
Name: "usagecleanuptask_status_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageCleanupTasksColumns[3], UsageCleanupTasksColumns[1]},
|
||||
},
|
||||
{
|
||||
Name: "usagecleanuptask_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageCleanupTasksColumns[1]},
|
||||
},
|
||||
{
|
||||
Name: "usagecleanuptask_canceled_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageCleanupTasksColumns[9]},
|
||||
},
|
||||
},
|
||||
}
|
||||
// UsageLogsColumns holds the columns for the "usage_logs" table.
|
||||
UsageLogsColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
@@ -449,6 +490,7 @@ var (
|
||||
{Name: "total_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "actual_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||
{Name: "account_rate_multiplier", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
|
||||
{Name: "billing_type", Type: field.TypeInt8, Default: 0},
|
||||
{Name: "stream", Type: field.TypeBool, Default: false},
|
||||
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
|
||||
@@ -472,31 +514,31 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "usage_logs_api_keys_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_accounts_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_groups_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_users_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -505,32 +547,32 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_account_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[26]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[27]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_subscription_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[24]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[25]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_model",
|
||||
@@ -545,12 +587,12 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -568,6 +610,9 @@ var (
|
||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
|
||||
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
|
||||
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
|
||||
}
|
||||
// UsersTable holds the schema information for the "users" table.
|
||||
UsersTable = &schema.Table{
|
||||
@@ -801,6 +846,7 @@ var (
|
||||
ProxiesTable,
|
||||
RedeemCodesTable,
|
||||
SettingsTable,
|
||||
UsageCleanupTasksTable,
|
||||
UsageLogsTable,
|
||||
UsersTable,
|
||||
UserAllowedGroupsTable,
|
||||
@@ -847,6 +893,9 @@ func init() {
|
||||
SettingsTable.Annotation = &entsql.Annotation{
|
||||
Table: "settings",
|
||||
}
|
||||
UsageCleanupTasksTable.Annotation = &entsql.Annotation{
|
||||
Table: "usage_cleanup_tasks",
|
||||
}
|
||||
UsageLogsTable.ForeignKeys[0].RefTable = APIKeysTable
|
||||
UsageLogsTable.ForeignKeys[1].RefTable = AccountsTable
|
||||
UsageLogsTable.ForeignKeys[2].RefTable = GroupsTable
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -33,6 +33,9 @@ type RedeemCode func(*sql.Selector)
|
||||
// Setting is the predicate function for setting builders.
|
||||
type Setting func(*sql.Selector)
|
||||
|
||||
// UsageCleanupTask is the predicate function for usagecleanuptask builders.
|
||||
type UsageCleanupTask func(*sql.Selector)
|
||||
|
||||
// UsageLog is the predicate function for usagelog builders.
|
||||
type UsageLog func(*sql.Selector)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
@@ -177,22 +178,26 @@ func init() {
|
||||
accountDescPriority := accountFields[8].Descriptor()
|
||||
// account.DefaultPriority holds the default value on creation for the priority field.
|
||||
account.DefaultPriority = accountDescPriority.Default.(int)
|
||||
// accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||
accountDescRateMultiplier := accountFields[9].Descriptor()
|
||||
// account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||
account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
|
||||
// accountDescStatus is the schema descriptor for status field.
|
||||
accountDescStatus := accountFields[9].Descriptor()
|
||||
accountDescStatus := accountFields[10].Descriptor()
|
||||
// account.DefaultStatus holds the default value on creation for the status field.
|
||||
account.DefaultStatus = accountDescStatus.Default.(string)
|
||||
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
|
||||
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
|
||||
accountDescAutoPauseOnExpired := accountFields[13].Descriptor()
|
||||
accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
|
||||
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
|
||||
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
|
||||
// accountDescSchedulable is the schema descriptor for schedulable field.
|
||||
accountDescSchedulable := accountFields[14].Descriptor()
|
||||
accountDescSchedulable := accountFields[15].Descriptor()
|
||||
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
|
||||
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
|
||||
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
|
||||
accountDescSessionWindowStatus := accountFields[20].Descriptor()
|
||||
accountDescSessionWindowStatus := accountFields[21].Descriptor()
|
||||
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
|
||||
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
|
||||
accountgroupFields := schema.AccountGroup{}.Fields()
|
||||
@@ -276,6 +281,10 @@ func init() {
|
||||
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
|
||||
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
||||
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
||||
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
|
||||
groupDescModelRoutingEnabled := groupFields[17].Descriptor()
|
||||
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
|
||||
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
|
||||
promocodeFields := schema.PromoCode{}.Fields()
|
||||
_ = promocodeFields
|
||||
// promocodeDescCode is the schema descriptor for code field.
|
||||
@@ -487,6 +496,43 @@ func init() {
|
||||
setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time)
|
||||
// setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||
setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time)
|
||||
usagecleanuptaskMixin := schema.UsageCleanupTask{}.Mixin()
|
||||
usagecleanuptaskMixinFields0 := usagecleanuptaskMixin[0].Fields()
|
||||
_ = usagecleanuptaskMixinFields0
|
||||
usagecleanuptaskFields := schema.UsageCleanupTask{}.Fields()
|
||||
_ = usagecleanuptaskFields
|
||||
// usagecleanuptaskDescCreatedAt is the schema descriptor for created_at field.
|
||||
usagecleanuptaskDescCreatedAt := usagecleanuptaskMixinFields0[0].Descriptor()
|
||||
// usagecleanuptask.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
usagecleanuptask.DefaultCreatedAt = usagecleanuptaskDescCreatedAt.Default.(func() time.Time)
|
||||
// usagecleanuptaskDescUpdatedAt is the schema descriptor for updated_at field.
|
||||
usagecleanuptaskDescUpdatedAt := usagecleanuptaskMixinFields0[1].Descriptor()
|
||||
// usagecleanuptask.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||
usagecleanuptask.DefaultUpdatedAt = usagecleanuptaskDescUpdatedAt.Default.(func() time.Time)
|
||||
// usagecleanuptask.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||
usagecleanuptask.UpdateDefaultUpdatedAt = usagecleanuptaskDescUpdatedAt.UpdateDefault.(func() time.Time)
|
||||
// usagecleanuptaskDescStatus is the schema descriptor for status field.
|
||||
usagecleanuptaskDescStatus := usagecleanuptaskFields[0].Descriptor()
|
||||
// usagecleanuptask.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
usagecleanuptask.StatusValidator = func() func(string) error {
|
||||
validators := usagecleanuptaskDescStatus.Validators
|
||||
fns := [...]func(string) error{
|
||||
validators[0].(func(string) error),
|
||||
validators[1].(func(string) error),
|
||||
}
|
||||
return func(status string) error {
|
||||
for _, fn := range fns {
|
||||
if err := fn(status); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
// usagecleanuptaskDescDeletedRows is the schema descriptor for deleted_rows field.
|
||||
usagecleanuptaskDescDeletedRows := usagecleanuptaskFields[3].Descriptor()
|
||||
// usagecleanuptask.DefaultDeletedRows holds the default value on creation for the deleted_rows field.
|
||||
usagecleanuptask.DefaultDeletedRows = usagecleanuptaskDescDeletedRows.Default.(int64)
|
||||
usagelogFields := schema.UsageLog{}.Fields()
|
||||
_ = usagelogFields
|
||||
// usagelogDescRequestID is the schema descriptor for request_id field.
|
||||
@@ -578,31 +624,31 @@ func init() {
|
||||
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
||||
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
||||
usagelogDescBillingType := usagelogFields[20].Descriptor()
|
||||
usagelogDescBillingType := usagelogFields[21].Descriptor()
|
||||
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
||||
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
||||
// usagelogDescStream is the schema descriptor for stream field.
|
||||
usagelogDescStream := usagelogFields[21].Descriptor()
|
||||
usagelogDescStream := usagelogFields[22].Descriptor()
|
||||
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
||||
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
||||
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
||||
usagelogDescUserAgent := usagelogFields[24].Descriptor()
|
||||
usagelogDescUserAgent := usagelogFields[25].Descriptor()
|
||||
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
||||
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
||||
usagelogDescIPAddress := usagelogFields[25].Descriptor()
|
||||
usagelogDescIPAddress := usagelogFields[26].Descriptor()
|
||||
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
||||
// usagelogDescImageCount is the schema descriptor for image_count field.
|
||||
usagelogDescImageCount := usagelogFields[26].Descriptor()
|
||||
usagelogDescImageCount := usagelogFields[27].Descriptor()
|
||||
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
||||
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
||||
// usagelogDescImageSize is the schema descriptor for image_size field.
|
||||
usagelogDescImageSize := usagelogFields[27].Descriptor()
|
||||
usagelogDescImageSize := usagelogFields[28].Descriptor()
|
||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||
usagelogDescCreatedAt := usagelogFields[28].Descriptor()
|
||||
usagelogDescCreatedAt := usagelogFields[29].Descriptor()
|
||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||
userMixin := schema.User{}.Mixin()
|
||||
@@ -690,6 +736,10 @@ func init() {
|
||||
userDescNotes := userFields[7].Descriptor()
|
||||
// user.DefaultNotes holds the default value on creation for the notes field.
|
||||
user.DefaultNotes = userDescNotes.Default.(string)
|
||||
// userDescTotpEnabled is the schema descriptor for totp_enabled field.
|
||||
userDescTotpEnabled := userFields[9].Descriptor()
|
||||
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
|
||||
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
|
||||
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
||||
_ = userallowedgroupFields
|
||||
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
||||
|
||||
@@ -102,6 +102,12 @@ func (Account) Fields() []ent.Field {
|
||||
field.Int("priority").
|
||||
Default(50),
|
||||
|
||||
// rate_multiplier: 账号计费倍率(>=0,允许 0 表示该账号计费为 0)
|
||||
// 仅影响账号维度计费口径,不影响用户/API Key 扣费(分组倍率)
|
||||
field.Float("rate_multiplier").
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
|
||||
Default(1.0),
|
||||
|
||||
// status: 账户状态,如 "active", "error", "disabled"
|
||||
field.String("status").
|
||||
MaxLen(20).
|
||||
|
||||
@@ -95,6 +95,17 @@ func (Group) Fields() []ent.Field {
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
||||
|
||||
// 模型路由配置 (added by migration 040)
|
||||
field.JSON("model_routing", map[string][]int64{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
|
||||
|
||||
// 模型路由开关 (added by migration 041)
|
||||
field.Bool("model_routing_enabled").
|
||||
Default(false).
|
||||
Comment("是否启用模型路由配置"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ package mixins
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent"
|
||||
@@ -12,7 +13,6 @@ import (
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/schema/field"
|
||||
"entgo.io/ent/schema/mixin"
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/intercept"
|
||||
)
|
||||
|
||||
@@ -113,7 +113,6 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook {
|
||||
SetOp(ent.Op)
|
||||
SetDeletedAt(time.Time)
|
||||
WhereP(...func(*sql.Selector))
|
||||
Client() *dbent.Client
|
||||
})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected mutation type %T", m)
|
||||
@@ -124,7 +123,7 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook {
|
||||
mx.SetOp(ent.OpUpdate)
|
||||
// 设置删除时间为当前时间
|
||||
mx.SetDeletedAt(time.Now())
|
||||
return mx.Client().Mutate(ctx, m)
|
||||
return mutateWithClient(ctx, m, next)
|
||||
})
|
||||
},
|
||||
}
|
||||
@@ -137,3 +136,41 @@ func (d SoftDeleteMixin) applyPredicate(w interface{ WhereP(...func(*sql.Selecto
|
||||
sql.FieldIsNull(d.Fields()[0].Descriptor().Name),
|
||||
)
|
||||
}
|
||||
|
||||
func mutateWithClient(ctx context.Context, m ent.Mutation, fallback ent.Mutator) (ent.Value, error) {
|
||||
clientMethod := reflect.ValueOf(m).MethodByName("Client")
|
||||
if !clientMethod.IsValid() || clientMethod.Type().NumIn() != 0 || clientMethod.Type().NumOut() != 1 {
|
||||
return nil, fmt.Errorf("soft delete: mutation client method not found for %T", m)
|
||||
}
|
||||
client := clientMethod.Call(nil)[0]
|
||||
mutateMethod := client.MethodByName("Mutate")
|
||||
if !mutateMethod.IsValid() {
|
||||
return nil, fmt.Errorf("soft delete: mutation client missing Mutate for %T", m)
|
||||
}
|
||||
if mutateMethod.Type().NumIn() != 2 || mutateMethod.Type().NumOut() != 2 {
|
||||
return nil, fmt.Errorf("soft delete: mutation client signature mismatch for %T", m)
|
||||
}
|
||||
|
||||
results := mutateMethod.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(m)})
|
||||
value := results[0].Interface()
|
||||
var err error
|
||||
if !results[1].IsNil() {
|
||||
errValue := results[1].Interface()
|
||||
typedErr, ok := errValue.(error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("soft delete: unexpected error type %T for %T", errValue, m)
|
||||
}
|
||||
err = typedErr
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if value == nil {
|
||||
return nil, fmt.Errorf("soft delete: mutation client returned nil for %T", m)
|
||||
}
|
||||
v, ok := value.(ent.Value)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("soft delete: unexpected value type %T for %T", value, m)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
75
backend/ent/schema/usage_cleanup_task.go
Normal file
75
backend/ent/schema/usage_cleanup_task.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect/entsql"
|
||||
"entgo.io/ent/schema"
|
||||
"entgo.io/ent/schema/field"
|
||||
"entgo.io/ent/schema/index"
|
||||
)
|
||||
|
||||
// UsageCleanupTask 定义使用记录清理任务的 schema。
|
||||
type UsageCleanupTask struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
func (UsageCleanupTask) Annotations() []schema.Annotation {
|
||||
return []schema.Annotation{
|
||||
entsql.Annotation{Table: "usage_cleanup_tasks"},
|
||||
}
|
||||
}
|
||||
|
||||
func (UsageCleanupTask) Mixin() []ent.Mixin {
|
||||
return []ent.Mixin{
|
||||
mixins.TimeMixin{},
|
||||
}
|
||||
}
|
||||
|
||||
func (UsageCleanupTask) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.String("status").
|
||||
MaxLen(20).
|
||||
Validate(validateUsageCleanupStatus),
|
||||
field.JSON("filters", json.RawMessage{}),
|
||||
field.Int64("created_by"),
|
||||
field.Int64("deleted_rows").
|
||||
Default(0),
|
||||
field.String("error_message").
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.Int64("canceled_by").
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.Time("canceled_at").
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.Time("started_at").
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.Time("finished_at").
|
||||
Optional().
|
||||
Nillable(),
|
||||
}
|
||||
}
|
||||
|
||||
func (UsageCleanupTask) Indexes() []ent.Index {
|
||||
return []ent.Index{
|
||||
index.Fields("status", "created_at"),
|
||||
index.Fields("created_at"),
|
||||
index.Fields("canceled_at"),
|
||||
}
|
||||
}
|
||||
|
||||
func validateUsageCleanupStatus(status string) error {
|
||||
switch status {
|
||||
case "pending", "running", "succeeded", "failed", "canceled":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("invalid usage cleanup status: %s", status)
|
||||
}
|
||||
}
|
||||
@@ -85,6 +85,12 @@ func (UsageLog) Fields() []ent.Field {
|
||||
Default(1).
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
|
||||
|
||||
// account_rate_multiplier: 账号计费倍率快照(NULL 表示按 1.0 处理)
|
||||
field.Float("account_rate_multiplier").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
|
||||
|
||||
// 其他字段
|
||||
field.Int8("billing_type").
|
||||
Default(0),
|
||||
|
||||
@@ -61,6 +61,17 @@ func (User) Fields() []ent.Field {
|
||||
field.String("notes").
|
||||
SchemaType(map[string]string{dialect.Postgres: "text"}).
|
||||
Default(""),
|
||||
|
||||
// TOTP 双因素认证字段
|
||||
field.String("totp_secret_encrypted").
|
||||
SchemaType(map[string]string{dialect.Postgres: "text"}).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.Bool("totp_enabled").
|
||||
Default(false),
|
||||
field.Time("totp_enabled_at").
|
||||
Optional().
|
||||
Nillable(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,8 @@ type Tx struct {
|
||||
RedeemCode *RedeemCodeClient
|
||||
// Setting is the client for interacting with the Setting builders.
|
||||
Setting *SettingClient
|
||||
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
|
||||
UsageCleanupTask *UsageCleanupTaskClient
|
||||
// UsageLog is the client for interacting with the UsageLog builders.
|
||||
UsageLog *UsageLogClient
|
||||
// User is the client for interacting with the User builders.
|
||||
@@ -184,6 +186,7 @@ func (tx *Tx) init() {
|
||||
tx.Proxy = NewProxyClient(tx.config)
|
||||
tx.RedeemCode = NewRedeemCodeClient(tx.config)
|
||||
tx.Setting = NewSettingClient(tx.config)
|
||||
tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config)
|
||||
tx.UsageLog = NewUsageLogClient(tx.config)
|
||||
tx.User = NewUserClient(tx.config)
|
||||
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)
|
||||
|
||||
236
backend/ent/usagecleanuptask.go
Normal file
236
backend/ent/usagecleanuptask.go
Normal file
@@ -0,0 +1,236 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
)
|
||||
|
||||
// UsageCleanupTask is the model entity for the UsageCleanupTask schema.
|
||||
type UsageCleanupTask struct {
|
||||
config `json:"-"`
|
||||
// ID of the ent.
|
||||
ID int64 `json:"id,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// UpdatedAt holds the value of the "updated_at" field.
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
// Status holds the value of the "status" field.
|
||||
Status string `json:"status,omitempty"`
|
||||
// Filters holds the value of the "filters" field.
|
||||
Filters json.RawMessage `json:"filters,omitempty"`
|
||||
// CreatedBy holds the value of the "created_by" field.
|
||||
CreatedBy int64 `json:"created_by,omitempty"`
|
||||
// DeletedRows holds the value of the "deleted_rows" field.
|
||||
DeletedRows int64 `json:"deleted_rows,omitempty"`
|
||||
// ErrorMessage holds the value of the "error_message" field.
|
||||
ErrorMessage *string `json:"error_message,omitempty"`
|
||||
// CanceledBy holds the value of the "canceled_by" field.
|
||||
CanceledBy *int64 `json:"canceled_by,omitempty"`
|
||||
// CanceledAt holds the value of the "canceled_at" field.
|
||||
CanceledAt *time.Time `json:"canceled_at,omitempty"`
|
||||
// StartedAt holds the value of the "started_at" field.
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
// FinishedAt holds the value of the "finished_at" field.
|
||||
FinishedAt *time.Time `json:"finished_at,omitempty"`
|
||||
selectValues sql.SelectValues
|
||||
}
|
||||
|
||||
// scanValues returns the types for scanning values from sql.Rows.
|
||||
func (*UsageCleanupTask) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case usagecleanuptask.FieldFilters:
|
||||
values[i] = new([]byte)
|
||||
case usagecleanuptask.FieldID, usagecleanuptask.FieldCreatedBy, usagecleanuptask.FieldDeletedRows, usagecleanuptask.FieldCanceledBy:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case usagecleanuptask.FieldStatus, usagecleanuptask.FieldErrorMessage:
|
||||
values[i] = new(sql.NullString)
|
||||
case usagecleanuptask.FieldCreatedAt, usagecleanuptask.FieldUpdatedAt, usagecleanuptask.FieldCanceledAt, usagecleanuptask.FieldStartedAt, usagecleanuptask.FieldFinishedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
}
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
// assignValues assigns the values that were returned from sql.Rows (after scanning)
|
||||
// to the UsageCleanupTask fields.
|
||||
func (_m *UsageCleanupTask) assignValues(columns []string, values []any) error {
|
||||
if m, n := len(values), len(columns); m < n {
|
||||
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
|
||||
}
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case usagecleanuptask.FieldID:
|
||||
value, ok := values[i].(*sql.NullInt64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field id", value)
|
||||
}
|
||||
_m.ID = int64(value.Int64)
|
||||
case usagecleanuptask.FieldCreatedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.CreatedAt = value.Time
|
||||
}
|
||||
case usagecleanuptask.FieldUpdatedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UpdatedAt = value.Time
|
||||
}
|
||||
case usagecleanuptask.FieldStatus:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field status", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Status = value.String
|
||||
}
|
||||
case usagecleanuptask.FieldFilters:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field filters", values[i])
|
||||
} else if value != nil && len(*value) > 0 {
|
||||
if err := json.Unmarshal(*value, &_m.Filters); err != nil {
|
||||
return fmt.Errorf("unmarshal field filters: %w", err)
|
||||
}
|
||||
}
|
||||
case usagecleanuptask.FieldCreatedBy:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field created_by", values[i])
|
||||
} else if value.Valid {
|
||||
_m.CreatedBy = value.Int64
|
||||
}
|
||||
case usagecleanuptask.FieldDeletedRows:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field deleted_rows", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DeletedRows = value.Int64
|
||||
}
|
||||
case usagecleanuptask.FieldErrorMessage:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field error_message", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ErrorMessage = new(string)
|
||||
*_m.ErrorMessage = value.String
|
||||
}
|
||||
case usagecleanuptask.FieldCanceledBy:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field canceled_by", values[i])
|
||||
} else if value.Valid {
|
||||
_m.CanceledBy = new(int64)
|
||||
*_m.CanceledBy = value.Int64
|
||||
}
|
||||
case usagecleanuptask.FieldCanceledAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field canceled_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.CanceledAt = new(time.Time)
|
||||
*_m.CanceledAt = value.Time
|
||||
}
|
||||
case usagecleanuptask.FieldStartedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field started_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.StartedAt = new(time.Time)
|
||||
*_m.StartedAt = value.Time
|
||||
}
|
||||
case usagecleanuptask.FieldFinishedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field finished_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.FinishedAt = new(time.Time)
|
||||
*_m.FinishedAt = value.Time
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value returns the ent.Value that was dynamically selected and assigned to the UsageCleanupTask.
|
||||
// This includes values selected through modifiers, order, etc.
|
||||
func (_m *UsageCleanupTask) Value(name string) (ent.Value, error) {
|
||||
return _m.selectValues.Get(name)
|
||||
}
|
||||
|
||||
// Update returns a builder for updating this UsageCleanupTask.
|
||||
// Note that you need to call UsageCleanupTask.Unwrap() before calling this method if this UsageCleanupTask
|
||||
// was returned from a transaction, and the transaction was committed or rolled back.
|
||||
func (_m *UsageCleanupTask) Update() *UsageCleanupTaskUpdateOne {
|
||||
return NewUsageCleanupTaskClient(_m.config).UpdateOne(_m)
|
||||
}
|
||||
|
||||
// Unwrap unwraps the UsageCleanupTask entity that was returned from a transaction after it was closed,
|
||||
// so that all future queries will be executed through the driver which created the transaction.
|
||||
func (_m *UsageCleanupTask) Unwrap() *UsageCleanupTask {
|
||||
_tx, ok := _m.config.driver.(*txDriver)
|
||||
if !ok {
|
||||
panic("ent: UsageCleanupTask is not a transactional entity")
|
||||
}
|
||||
_m.config.driver = _tx.drv
|
||||
return _m
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer.
|
||||
func (_m *UsageCleanupTask) String() string {
|
||||
var builder strings.Builder
|
||||
builder.WriteString("UsageCleanupTask(")
|
||||
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
|
||||
builder.WriteString("created_at=")
|
||||
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("updated_at=")
|
||||
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("status=")
|
||||
builder.WriteString(_m.Status)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("filters=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Filters))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("created_by=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.CreatedBy))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("deleted_rows=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.DeletedRows))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ErrorMessage; v != nil {
|
||||
builder.WriteString("error_message=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.CanceledBy; v != nil {
|
||||
builder.WriteString("canceled_by=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.CanceledAt; v != nil {
|
||||
builder.WriteString("canceled_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.StartedAt; v != nil {
|
||||
builder.WriteString("started_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.FinishedAt; v != nil {
|
||||
builder.WriteString("finished_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// UsageCleanupTasks is a parsable slice of UsageCleanupTask.
|
||||
type UsageCleanupTasks []*UsageCleanupTask
|
||||
137
backend/ent/usagecleanuptask/usagecleanuptask.go
Normal file
137
backend/ent/usagecleanuptask/usagecleanuptask.go
Normal file
@@ -0,0 +1,137 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package usagecleanuptask
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
const (
|
||||
// Label holds the string label denoting the usagecleanuptask type in the database.
|
||||
Label = "usage_cleanup_task"
|
||||
// FieldID holds the string denoting the id field in the database.
|
||||
FieldID = "id"
|
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
FieldCreatedAt = "created_at"
|
||||
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
|
||||
FieldUpdatedAt = "updated_at"
|
||||
// FieldStatus holds the string denoting the status field in the database.
|
||||
FieldStatus = "status"
|
||||
// FieldFilters holds the string denoting the filters field in the database.
|
||||
FieldFilters = "filters"
|
||||
// FieldCreatedBy holds the string denoting the created_by field in the database.
|
||||
FieldCreatedBy = "created_by"
|
||||
// FieldDeletedRows holds the string denoting the deleted_rows field in the database.
|
||||
FieldDeletedRows = "deleted_rows"
|
||||
// FieldErrorMessage holds the string denoting the error_message field in the database.
|
||||
FieldErrorMessage = "error_message"
|
||||
// FieldCanceledBy holds the string denoting the canceled_by field in the database.
|
||||
FieldCanceledBy = "canceled_by"
|
||||
// FieldCanceledAt holds the string denoting the canceled_at field in the database.
|
||||
FieldCanceledAt = "canceled_at"
|
||||
// FieldStartedAt holds the string denoting the started_at field in the database.
|
||||
FieldStartedAt = "started_at"
|
||||
// FieldFinishedAt holds the string denoting the finished_at field in the database.
|
||||
FieldFinishedAt = "finished_at"
|
||||
// Table holds the table name of the usagecleanuptask in the database.
|
||||
Table = "usage_cleanup_tasks"
|
||||
)
|
||||
|
||||
// Columns holds all SQL columns for usagecleanuptask fields.
|
||||
var Columns = []string{
|
||||
FieldID,
|
||||
FieldCreatedAt,
|
||||
FieldUpdatedAt,
|
||||
FieldStatus,
|
||||
FieldFilters,
|
||||
FieldCreatedBy,
|
||||
FieldDeletedRows,
|
||||
FieldErrorMessage,
|
||||
FieldCanceledBy,
|
||||
FieldCanceledAt,
|
||||
FieldStartedAt,
|
||||
FieldFinishedAt,
|
||||
}
|
||||
|
||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||
func ValidColumn(column string) bool {
|
||||
for i := range Columns {
|
||||
if column == Columns[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
DefaultCreatedAt func() time.Time
|
||||
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||
DefaultUpdatedAt func() time.Time
|
||||
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
|
||||
UpdateDefaultUpdatedAt func() time.Time
|
||||
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||
StatusValidator func(string) error
|
||||
// DefaultDeletedRows holds the default value on creation for the "deleted_rows" field.
|
||||
DefaultDeletedRows int64
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the UsageCleanupTask queries.
|
||||
type OrderOption func(*sql.Selector)
|
||||
|
||||
// ByID orders the results by the id field.
|
||||
func ByID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCreatedAt orders the results by the created_at field.
|
||||
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUpdatedAt orders the results by the updated_at field.
|
||||
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByStatus orders the results by the status field.
|
||||
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldStatus, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCreatedBy orders the results by the created_by field.
|
||||
func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCreatedBy, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDeletedRows orders the results by the deleted_rows field.
|
||||
func ByDeletedRows(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDeletedRows, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByErrorMessage orders the results by the error_message field.
|
||||
func ByErrorMessage(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldErrorMessage, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCanceledBy orders the results by the canceled_by field.
|
||||
func ByCanceledBy(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCanceledBy, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCanceledAt orders the results by the canceled_at field.
|
||||
func ByCanceledAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCanceledAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByStartedAt orders the results by the started_at field.
|
||||
func ByStartedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldStartedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByFinishedAt orders the results by the finished_at field.
|
||||
func ByFinishedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldFinishedAt, opts...).ToFunc()
|
||||
}
|
||||
620
backend/ent/usagecleanuptask/where.go
Normal file
620
backend/ent/usagecleanuptask/where.go
Normal file
@@ -0,0 +1,620 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package usagecleanuptask
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
)
|
||||
|
||||
// ID filters vertices based on their ID field.
|
||||
func ID(id int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDEQ applies the EQ predicate on the ID field.
|
||||
func IDEQ(id int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDNEQ applies the NEQ predicate on the ID field.
|
||||
func IDNEQ(id int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDIn applies the In predicate on the ID field.
|
||||
func IDIn(ids ...int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIn(FieldID, ids...))
|
||||
}
|
||||
|
||||
// IDNotIn applies the NotIn predicate on the ID field.
|
||||
func IDNotIn(ids ...int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldID, ids...))
|
||||
}
|
||||
|
||||
// IDGT applies the GT predicate on the ID field.
|
||||
func IDGT(id int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGT(FieldID, id))
|
||||
}
|
||||
|
||||
// IDGTE applies the GTE predicate on the ID field.
|
||||
func IDGTE(id int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGTE(FieldID, id))
|
||||
}
|
||||
|
||||
// IDLT applies the LT predicate on the ID field.
|
||||
func IDLT(id int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLT(FieldID, id))
|
||||
}
|
||||
|
||||
// IDLTE applies the LTE predicate on the ID field.
|
||||
func IDLTE(id int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLTE(FieldID, id))
|
||||
}
|
||||
|
||||
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||
func CreatedAt(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
|
||||
func UpdatedAt(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
|
||||
func Status(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldStatus, v))
|
||||
}
|
||||
|
||||
// CreatedBy applies equality check predicate on the "created_by" field. It's identical to CreatedByEQ.
|
||||
func CreatedBy(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedBy, v))
|
||||
}
|
||||
|
||||
// DeletedRows applies equality check predicate on the "deleted_rows" field. It's identical to DeletedRowsEQ.
|
||||
func DeletedRows(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldDeletedRows, v))
|
||||
}
|
||||
|
||||
// ErrorMessage applies equality check predicate on the "error_message" field. It's identical to ErrorMessageEQ.
|
||||
func ErrorMessage(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldErrorMessage, v))
|
||||
}
|
||||
|
||||
// CanceledBy applies equality check predicate on the "canceled_by" field. It's identical to CanceledByEQ.
|
||||
func CanceledBy(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledBy, v))
|
||||
}
|
||||
|
||||
// CanceledAt applies equality check predicate on the "canceled_at" field. It's identical to CanceledAtEQ.
|
||||
func CanceledAt(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledAt, v))
|
||||
}
|
||||
|
||||
// StartedAt applies equality check predicate on the "started_at" field. It's identical to StartedAtEQ.
|
||||
func StartedAt(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldStartedAt, v))
|
||||
}
|
||||
|
||||
// FinishedAt applies equality check predicate on the "finished_at" field. It's identical to FinishedAtEQ.
|
||||
func FinishedAt(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldFinishedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
|
||||
func CreatedAtNEQ(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtIn applies the In predicate on the "created_at" field.
|
||||
func CreatedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIn(FieldCreatedAt, vs...))
|
||||
}
|
||||
|
||||
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
|
||||
func CreatedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCreatedAt, vs...))
|
||||
}
|
||||
|
||||
// CreatedAtGT applies the GT predicate on the "created_at" field.
|
||||
func CreatedAtGT(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGT(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
|
||||
func CreatedAtGTE(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGTE(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtLT applies the LT predicate on the "created_at" field.
|
||||
func CreatedAtLT(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLT(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
|
||||
func CreatedAtLTE(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLTE(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
|
||||
func UpdatedAtEQ(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
|
||||
func UpdatedAtNEQ(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtIn applies the In predicate on the "updated_at" field.
|
||||
func UpdatedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIn(FieldUpdatedAt, vs...))
|
||||
}
|
||||
|
||||
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
|
||||
func UpdatedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldUpdatedAt, vs...))
|
||||
}
|
||||
|
||||
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
|
||||
func UpdatedAtGT(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGT(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
|
||||
func UpdatedAtGTE(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGTE(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
|
||||
func UpdatedAtLT(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLT(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
|
||||
func UpdatedAtLTE(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLTE(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// StatusEQ applies the EQ predicate on the "status" field.
|
||||
func StatusEQ(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusNEQ applies the NEQ predicate on the "status" field.
|
||||
func StatusNEQ(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusIn applies the In predicate on the "status" field.
|
||||
func StatusIn(vs ...string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIn(FieldStatus, vs...))
|
||||
}
|
||||
|
||||
// StatusNotIn applies the NotIn predicate on the "status" field.
|
||||
func StatusNotIn(vs ...string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldStatus, vs...))
|
||||
}
|
||||
|
||||
// StatusGT applies the GT predicate on the "status" field.
|
||||
func StatusGT(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGT(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusGTE applies the GTE predicate on the "status" field.
|
||||
func StatusGTE(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGTE(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusLT applies the LT predicate on the "status" field.
|
||||
func StatusLT(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLT(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusLTE applies the LTE predicate on the "status" field.
|
||||
func StatusLTE(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLTE(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusContains applies the Contains predicate on the "status" field.
|
||||
func StatusContains(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldContains(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusHasPrefix applies the HasPrefix predicate on the "status" field.
|
||||
func StatusHasPrefix(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldHasPrefix(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusHasSuffix applies the HasSuffix predicate on the "status" field.
|
||||
func StatusHasSuffix(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldHasSuffix(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusEqualFold applies the EqualFold predicate on the "status" field.
|
||||
func StatusEqualFold(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEqualFold(FieldStatus, v))
|
||||
}
|
||||
|
||||
// StatusContainsFold applies the ContainsFold predicate on the "status" field.
|
||||
func StatusContainsFold(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldContainsFold(FieldStatus, v))
|
||||
}
|
||||
|
||||
// CreatedByEQ applies the EQ predicate on the "created_by" field.
|
||||
func CreatedByEQ(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedBy, v))
|
||||
}
|
||||
|
||||
// CreatedByNEQ applies the NEQ predicate on the "created_by" field.
|
||||
func CreatedByNEQ(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCreatedBy, v))
|
||||
}
|
||||
|
||||
// CreatedByIn applies the In predicate on the "created_by" field.
|
||||
func CreatedByIn(vs ...int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIn(FieldCreatedBy, vs...))
|
||||
}
|
||||
|
||||
// CreatedByNotIn applies the NotIn predicate on the "created_by" field.
|
||||
func CreatedByNotIn(vs ...int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCreatedBy, vs...))
|
||||
}
|
||||
|
||||
// CreatedByGT applies the GT predicate on the "created_by" field.
|
||||
func CreatedByGT(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGT(FieldCreatedBy, v))
|
||||
}
|
||||
|
||||
// CreatedByGTE applies the GTE predicate on the "created_by" field.
|
||||
func CreatedByGTE(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGTE(FieldCreatedBy, v))
|
||||
}
|
||||
|
||||
// CreatedByLT applies the LT predicate on the "created_by" field.
|
||||
func CreatedByLT(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLT(FieldCreatedBy, v))
|
||||
}
|
||||
|
||||
// CreatedByLTE applies the LTE predicate on the "created_by" field.
|
||||
func CreatedByLTE(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLTE(FieldCreatedBy, v))
|
||||
}
|
||||
|
||||
// DeletedRowsEQ applies the EQ predicate on the "deleted_rows" field.
|
||||
func DeletedRowsEQ(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldDeletedRows, v))
|
||||
}
|
||||
|
||||
// DeletedRowsNEQ applies the NEQ predicate on the "deleted_rows" field.
|
||||
func DeletedRowsNEQ(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldDeletedRows, v))
|
||||
}
|
||||
|
||||
// DeletedRowsIn applies the In predicate on the "deleted_rows" field.
|
||||
func DeletedRowsIn(vs ...int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIn(FieldDeletedRows, vs...))
|
||||
}
|
||||
|
||||
// DeletedRowsNotIn applies the NotIn predicate on the "deleted_rows" field.
|
||||
func DeletedRowsNotIn(vs ...int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldDeletedRows, vs...))
|
||||
}
|
||||
|
||||
// DeletedRowsGT applies the GT predicate on the "deleted_rows" field.
|
||||
func DeletedRowsGT(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGT(FieldDeletedRows, v))
|
||||
}
|
||||
|
||||
// DeletedRowsGTE applies the GTE predicate on the "deleted_rows" field.
|
||||
func DeletedRowsGTE(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGTE(FieldDeletedRows, v))
|
||||
}
|
||||
|
||||
// DeletedRowsLT applies the LT predicate on the "deleted_rows" field.
|
||||
func DeletedRowsLT(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLT(FieldDeletedRows, v))
|
||||
}
|
||||
|
||||
// DeletedRowsLTE applies the LTE predicate on the "deleted_rows" field.
|
||||
func DeletedRowsLTE(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLTE(FieldDeletedRows, v))
|
||||
}
|
||||
|
||||
// ErrorMessageEQ applies the EQ predicate on the "error_message" field.
|
||||
func ErrorMessageEQ(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldErrorMessage, v))
|
||||
}
|
||||
|
||||
// ErrorMessageNEQ applies the NEQ predicate on the "error_message" field.
|
||||
func ErrorMessageNEQ(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldErrorMessage, v))
|
||||
}
|
||||
|
||||
// ErrorMessageIn applies the In predicate on the "error_message" field.
|
||||
func ErrorMessageIn(vs ...string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIn(FieldErrorMessage, vs...))
|
||||
}
|
||||
|
||||
// ErrorMessageNotIn applies the NotIn predicate on the "error_message" field.
|
||||
func ErrorMessageNotIn(vs ...string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldErrorMessage, vs...))
|
||||
}
|
||||
|
||||
// ErrorMessageGT applies the GT predicate on the "error_message" field.
|
||||
func ErrorMessageGT(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGT(FieldErrorMessage, v))
|
||||
}
|
||||
|
||||
// ErrorMessageGTE applies the GTE predicate on the "error_message" field.
|
||||
func ErrorMessageGTE(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGTE(FieldErrorMessage, v))
|
||||
}
|
||||
|
||||
// ErrorMessageLT applies the LT predicate on the "error_message" field.
|
||||
func ErrorMessageLT(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLT(FieldErrorMessage, v))
|
||||
}
|
||||
|
||||
// ErrorMessageLTE applies the LTE predicate on the "error_message" field.
|
||||
func ErrorMessageLTE(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLTE(FieldErrorMessage, v))
|
||||
}
|
||||
|
||||
// ErrorMessageContains applies the Contains predicate on the "error_message" field.
|
||||
func ErrorMessageContains(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldContains(FieldErrorMessage, v))
|
||||
}
|
||||
|
||||
// ErrorMessageHasPrefix applies the HasPrefix predicate on the "error_message" field.
|
||||
func ErrorMessageHasPrefix(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldHasPrefix(FieldErrorMessage, v))
|
||||
}
|
||||
|
||||
// ErrorMessageHasSuffix applies the HasSuffix predicate on the "error_message" field.
|
||||
func ErrorMessageHasSuffix(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldHasSuffix(FieldErrorMessage, v))
|
||||
}
|
||||
|
||||
// ErrorMessageIsNil applies the IsNil predicate on the "error_message" field.
|
||||
func ErrorMessageIsNil() predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldErrorMessage))
|
||||
}
|
||||
|
||||
// ErrorMessageNotNil applies the NotNil predicate on the "error_message" field.
|
||||
func ErrorMessageNotNil() predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldErrorMessage))
|
||||
}
|
||||
|
||||
// ErrorMessageEqualFold applies the EqualFold predicate on the "error_message" field.
|
||||
func ErrorMessageEqualFold(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEqualFold(FieldErrorMessage, v))
|
||||
}
|
||||
|
||||
// ErrorMessageContainsFold applies the ContainsFold predicate on the "error_message" field.
|
||||
func ErrorMessageContainsFold(v string) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldContainsFold(FieldErrorMessage, v))
|
||||
}
|
||||
|
||||
// CanceledByEQ applies the EQ predicate on the "canceled_by" field.
|
||||
func CanceledByEQ(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledBy, v))
|
||||
}
|
||||
|
||||
// CanceledByNEQ applies the NEQ predicate on the "canceled_by" field.
|
||||
func CanceledByNEQ(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCanceledBy, v))
|
||||
}
|
||||
|
||||
// CanceledByIn applies the In predicate on the "canceled_by" field.
|
||||
func CanceledByIn(vs ...int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIn(FieldCanceledBy, vs...))
|
||||
}
|
||||
|
||||
// CanceledByNotIn applies the NotIn predicate on the "canceled_by" field.
|
||||
func CanceledByNotIn(vs ...int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCanceledBy, vs...))
|
||||
}
|
||||
|
||||
// CanceledByGT applies the GT predicate on the "canceled_by" field.
|
||||
func CanceledByGT(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGT(FieldCanceledBy, v))
|
||||
}
|
||||
|
||||
// CanceledByGTE applies the GTE predicate on the "canceled_by" field.
|
||||
func CanceledByGTE(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGTE(FieldCanceledBy, v))
|
||||
}
|
||||
|
||||
// CanceledByLT applies the LT predicate on the "canceled_by" field.
|
||||
func CanceledByLT(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLT(FieldCanceledBy, v))
|
||||
}
|
||||
|
||||
// CanceledByLTE applies the LTE predicate on the "canceled_by" field.
|
||||
func CanceledByLTE(v int64) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLTE(FieldCanceledBy, v))
|
||||
}
|
||||
|
||||
// CanceledByIsNil applies the IsNil predicate on the "canceled_by" field.
|
||||
func CanceledByIsNil() predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldCanceledBy))
|
||||
}
|
||||
|
||||
// CanceledByNotNil applies the NotNil predicate on the "canceled_by" field.
|
||||
func CanceledByNotNil() predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldCanceledBy))
|
||||
}
|
||||
|
||||
// CanceledAtEQ applies the EQ predicate on the "canceled_at" field.
|
||||
func CanceledAtEQ(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledAt, v))
|
||||
}
|
||||
|
||||
// CanceledAtNEQ applies the NEQ predicate on the "canceled_at" field.
|
||||
func CanceledAtNEQ(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCanceledAt, v))
|
||||
}
|
||||
|
||||
// CanceledAtIn applies the In predicate on the "canceled_at" field.
|
||||
func CanceledAtIn(vs ...time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIn(FieldCanceledAt, vs...))
|
||||
}
|
||||
|
||||
// CanceledAtNotIn applies the NotIn predicate on the "canceled_at" field.
|
||||
func CanceledAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCanceledAt, vs...))
|
||||
}
|
||||
|
||||
// CanceledAtGT applies the GT predicate on the "canceled_at" field.
|
||||
func CanceledAtGT(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGT(FieldCanceledAt, v))
|
||||
}
|
||||
|
||||
// CanceledAtGTE applies the GTE predicate on the "canceled_at" field.
|
||||
func CanceledAtGTE(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGTE(FieldCanceledAt, v))
|
||||
}
|
||||
|
||||
// CanceledAtLT applies the LT predicate on the "canceled_at" field.
|
||||
func CanceledAtLT(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLT(FieldCanceledAt, v))
|
||||
}
|
||||
|
||||
// CanceledAtLTE applies the LTE predicate on the "canceled_at" field.
|
||||
func CanceledAtLTE(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLTE(FieldCanceledAt, v))
|
||||
}
|
||||
|
||||
// CanceledAtIsNil applies the IsNil predicate on the "canceled_at" field.
|
||||
func CanceledAtIsNil() predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldCanceledAt))
|
||||
}
|
||||
|
||||
// CanceledAtNotNil applies the NotNil predicate on the "canceled_at" field.
|
||||
func CanceledAtNotNil() predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldCanceledAt))
|
||||
}
|
||||
|
||||
// StartedAtEQ applies the EQ predicate on the "started_at" field.
|
||||
func StartedAtEQ(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldStartedAt, v))
|
||||
}
|
||||
|
||||
// StartedAtNEQ applies the NEQ predicate on the "started_at" field.
|
||||
func StartedAtNEQ(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldStartedAt, v))
|
||||
}
|
||||
|
||||
// StartedAtIn applies the In predicate on the "started_at" field.
|
||||
func StartedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIn(FieldStartedAt, vs...))
|
||||
}
|
||||
|
||||
// StartedAtNotIn applies the NotIn predicate on the "started_at" field.
|
||||
func StartedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldStartedAt, vs...))
|
||||
}
|
||||
|
||||
// StartedAtGT applies the GT predicate on the "started_at" field.
|
||||
func StartedAtGT(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGT(FieldStartedAt, v))
|
||||
}
|
||||
|
||||
// StartedAtGTE applies the GTE predicate on the "started_at" field.
|
||||
func StartedAtGTE(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGTE(FieldStartedAt, v))
|
||||
}
|
||||
|
||||
// StartedAtLT applies the LT predicate on the "started_at" field.
|
||||
func StartedAtLT(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLT(FieldStartedAt, v))
|
||||
}
|
||||
|
||||
// StartedAtLTE applies the LTE predicate on the "started_at" field.
|
||||
func StartedAtLTE(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLTE(FieldStartedAt, v))
|
||||
}
|
||||
|
||||
// StartedAtIsNil applies the IsNil predicate on the "started_at" field.
|
||||
func StartedAtIsNil() predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldStartedAt))
|
||||
}
|
||||
|
||||
// StartedAtNotNil applies the NotNil predicate on the "started_at" field.
|
||||
func StartedAtNotNil() predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldStartedAt))
|
||||
}
|
||||
|
||||
// FinishedAtEQ applies the EQ predicate on the "finished_at" field.
|
||||
func FinishedAtEQ(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldEQ(FieldFinishedAt, v))
|
||||
}
|
||||
|
||||
// FinishedAtNEQ applies the NEQ predicate on the "finished_at" field.
|
||||
func FinishedAtNEQ(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldFinishedAt, v))
|
||||
}
|
||||
|
||||
// FinishedAtIn applies the In predicate on the "finished_at" field.
|
||||
func FinishedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIn(FieldFinishedAt, vs...))
|
||||
}
|
||||
|
||||
// FinishedAtNotIn applies the NotIn predicate on the "finished_at" field.
|
||||
func FinishedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldFinishedAt, vs...))
|
||||
}
|
||||
|
||||
// FinishedAtGT applies the GT predicate on the "finished_at" field.
|
||||
func FinishedAtGT(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGT(FieldFinishedAt, v))
|
||||
}
|
||||
|
||||
// FinishedAtGTE applies the GTE predicate on the "finished_at" field.
|
||||
func FinishedAtGTE(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldGTE(FieldFinishedAt, v))
|
||||
}
|
||||
|
||||
// FinishedAtLT applies the LT predicate on the "finished_at" field.
|
||||
func FinishedAtLT(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLT(FieldFinishedAt, v))
|
||||
}
|
||||
|
||||
// FinishedAtLTE applies the LTE predicate on the "finished_at" field.
|
||||
func FinishedAtLTE(v time.Time) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldLTE(FieldFinishedAt, v))
|
||||
}
|
||||
|
||||
// FinishedAtIsNil applies the IsNil predicate on the "finished_at" field.
|
||||
func FinishedAtIsNil() predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldFinishedAt))
|
||||
}
|
||||
|
||||
// FinishedAtNotNil applies the NotNil predicate on the "finished_at" field.
|
||||
func FinishedAtNotNil() predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldFinishedAt))
|
||||
}
|
||||
|
||||
// And groups predicates with the AND operator between them.
|
||||
func And(predicates ...predicate.UsageCleanupTask) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.AndPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Or groups predicates with the OR operator between them.
|
||||
func Or(predicates ...predicate.UsageCleanupTask) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.OrPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Not applies the not operator on the given predicate.
|
||||
func Not(p predicate.UsageCleanupTask) predicate.UsageCleanupTask {
|
||||
return predicate.UsageCleanupTask(sql.NotPredicates(p))
|
||||
}
|
||||
1190
backend/ent/usagecleanuptask_create.go
Normal file
1190
backend/ent/usagecleanuptask_create.go
Normal file
File diff suppressed because it is too large
Load Diff
88
backend/ent/usagecleanuptask_delete.go
Normal file
88
backend/ent/usagecleanuptask_delete.go
Normal file
@@ -0,0 +1,88 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
)
|
||||
|
||||
// UsageCleanupTaskDelete is the builder for deleting a UsageCleanupTask entity.
|
||||
type UsageCleanupTaskDelete struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *UsageCleanupTaskMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UsageCleanupTaskDelete builder.
|
||||
func (_d *UsageCleanupTaskDelete) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskDelete {
|
||||
_d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query and returns how many vertices were deleted.
|
||||
func (_d *UsageCleanupTaskDelete) Exec(ctx context.Context) (int, error) {
|
||||
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *UsageCleanupTaskDelete) ExecX(ctx context.Context) int {
|
||||
n, err := _d.Exec(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (_d *UsageCleanupTaskDelete) sqlExec(ctx context.Context) (int, error) {
|
||||
_spec := sqlgraph.NewDeleteSpec(usagecleanuptask.Table, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
|
||||
if ps := _d.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
|
||||
if err != nil && sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
_d.mutation.done = true
|
||||
return affected, err
|
||||
}
|
||||
|
||||
// UsageCleanupTaskDeleteOne is the builder for deleting a single UsageCleanupTask entity.
|
||||
type UsageCleanupTaskDeleteOne struct {
|
||||
_d *UsageCleanupTaskDelete
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UsageCleanupTaskDelete builder.
|
||||
func (_d *UsageCleanupTaskDeleteOne) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskDeleteOne {
|
||||
_d._d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query.
|
||||
func (_d *UsageCleanupTaskDeleteOne) Exec(ctx context.Context) error {
|
||||
n, err := _d._d.Exec(ctx)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
case n == 0:
|
||||
return &NotFoundError{usagecleanuptask.Label}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *UsageCleanupTaskDeleteOne) ExecX(ctx context.Context) {
|
||||
if err := _d.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
564
backend/ent/usagecleanuptask_query.go
Normal file
564
backend/ent/usagecleanuptask_query.go
Normal file
@@ -0,0 +1,564 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
)
|
||||
|
||||
// UsageCleanupTaskQuery is the builder for querying UsageCleanupTask entities.
|
||||
type UsageCleanupTaskQuery struct {
|
||||
config
|
||||
ctx *QueryContext
|
||||
order []usagecleanuptask.OrderOption
|
||||
inters []Interceptor
|
||||
predicates []predicate.UsageCleanupTask
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
}
|
||||
|
||||
// Where adds a new predicate for the UsageCleanupTaskQuery builder.
|
||||
func (_q *UsageCleanupTaskQuery) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskQuery {
|
||||
_q.predicates = append(_q.predicates, ps...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// Limit the number of records to be returned by this query.
|
||||
func (_q *UsageCleanupTaskQuery) Limit(limit int) *UsageCleanupTaskQuery {
|
||||
_q.ctx.Limit = &limit
|
||||
return _q
|
||||
}
|
||||
|
||||
// Offset to start from.
|
||||
func (_q *UsageCleanupTaskQuery) Offset(offset int) *UsageCleanupTaskQuery {
|
||||
_q.ctx.Offset = &offset
|
||||
return _q
|
||||
}
|
||||
|
||||
// Unique configures the query builder to filter duplicate records on query.
|
||||
// By default, unique is set to true, and can be disabled using this method.
|
||||
func (_q *UsageCleanupTaskQuery) Unique(unique bool) *UsageCleanupTaskQuery {
|
||||
_q.ctx.Unique = &unique
|
||||
return _q
|
||||
}
|
||||
|
||||
// Order specifies how the records should be ordered.
|
||||
func (_q *UsageCleanupTaskQuery) Order(o ...usagecleanuptask.OrderOption) *UsageCleanupTaskQuery {
|
||||
_q.order = append(_q.order, o...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// First returns the first UsageCleanupTask entity from the query.
|
||||
// Returns a *NotFoundError when no UsageCleanupTask was found.
|
||||
func (_q *UsageCleanupTaskQuery) First(ctx context.Context) (*UsageCleanupTask, error) {
|
||||
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil, &NotFoundError{usagecleanuptask.Label}
|
||||
}
|
||||
return nodes[0], nil
|
||||
}
|
||||
|
||||
// FirstX is like First, but panics if an error occurs.
|
||||
func (_q *UsageCleanupTaskQuery) FirstX(ctx context.Context) *UsageCleanupTask {
|
||||
node, err := _q.First(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// FirstID returns the first UsageCleanupTask ID from the query.
|
||||
// Returns a *NotFoundError when no UsageCleanupTask ID was found.
|
||||
func (_q *UsageCleanupTaskQuery) FirstID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
|
||||
return
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
err = &NotFoundError{usagecleanuptask.Label}
|
||||
return
|
||||
}
|
||||
return ids[0], nil
|
||||
}
|
||||
|
||||
// FirstIDX is like FirstID, but panics if an error occurs.
|
||||
func (_q *UsageCleanupTaskQuery) FirstIDX(ctx context.Context) int64 {
|
||||
id, err := _q.FirstID(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// Only returns a single UsageCleanupTask entity found by the query, ensuring it only returns one.
|
||||
// Returns a *NotSingularError when more than one UsageCleanupTask entity is found.
|
||||
// Returns a *NotFoundError when no UsageCleanupTask entities are found.
|
||||
func (_q *UsageCleanupTaskQuery) Only(ctx context.Context) (*UsageCleanupTask, error) {
|
||||
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch len(nodes) {
|
||||
case 1:
|
||||
return nodes[0], nil
|
||||
case 0:
|
||||
return nil, &NotFoundError{usagecleanuptask.Label}
|
||||
default:
|
||||
return nil, &NotSingularError{usagecleanuptask.Label}
|
||||
}
|
||||
}
|
||||
|
||||
// OnlyX is like Only, but panics if an error occurs.
|
||||
func (_q *UsageCleanupTaskQuery) OnlyX(ctx context.Context) *UsageCleanupTask {
|
||||
node, err := _q.Only(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// OnlyID is like Only, but returns the only UsageCleanupTask ID in the query.
|
||||
// Returns a *NotSingularError when more than one UsageCleanupTask ID is found.
|
||||
// Returns a *NotFoundError when no entities are found.
|
||||
func (_q *UsageCleanupTaskQuery) OnlyID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
|
||||
return
|
||||
}
|
||||
switch len(ids) {
|
||||
case 1:
|
||||
id = ids[0]
|
||||
case 0:
|
||||
err = &NotFoundError{usagecleanuptask.Label}
|
||||
default:
|
||||
err = &NotSingularError{usagecleanuptask.Label}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// OnlyIDX is like OnlyID, but panics if an error occurs.
|
||||
func (_q *UsageCleanupTaskQuery) OnlyIDX(ctx context.Context) int64 {
|
||||
id, err := _q.OnlyID(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// All executes the query and returns a list of UsageCleanupTasks.
|
||||
func (_q *UsageCleanupTaskQuery) All(ctx context.Context) ([]*UsageCleanupTask, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qr := querierAll[[]*UsageCleanupTask, *UsageCleanupTaskQuery]()
|
||||
return withInterceptors[[]*UsageCleanupTask](ctx, _q, qr, _q.inters)
|
||||
}
|
||||
|
||||
// AllX is like All, but panics if an error occurs.
|
||||
func (_q *UsageCleanupTaskQuery) AllX(ctx context.Context) []*UsageCleanupTask {
|
||||
nodes, err := _q.All(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
// IDs executes the query and returns a list of UsageCleanupTask IDs.
|
||||
func (_q *UsageCleanupTaskQuery) IDs(ctx context.Context) (ids []int64, err error) {
|
||||
if _q.ctx.Unique == nil && _q.path != nil {
|
||||
_q.Unique(true)
|
||||
}
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
|
||||
if err = _q.Select(usagecleanuptask.FieldID).Scan(ctx, &ids); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// IDsX is like IDs, but panics if an error occurs.
|
||||
func (_q *UsageCleanupTaskQuery) IDsX(ctx context.Context) []int64 {
|
||||
ids, err := _q.IDs(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Count returns the count of the given query.
|
||||
func (_q *UsageCleanupTaskQuery) Count(ctx context.Context) (int, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return withInterceptors[int](ctx, _q, querierCount[*UsageCleanupTaskQuery](), _q.inters)
|
||||
}
|
||||
|
||||
// CountX is like Count, but panics if an error occurs.
|
||||
func (_q *UsageCleanupTaskQuery) CountX(ctx context.Context) int {
|
||||
count, err := _q.Count(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Exist returns true if the query has elements in the graph.
|
||||
func (_q *UsageCleanupTaskQuery) Exist(ctx context.Context) (bool, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
|
||||
switch _, err := _q.FirstID(ctx); {
|
||||
case IsNotFound(err):
|
||||
return false, nil
|
||||
case err != nil:
|
||||
return false, fmt.Errorf("ent: check existence: %w", err)
|
||||
default:
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExistX is like Exist, but panics if an error occurs.
|
||||
func (_q *UsageCleanupTaskQuery) ExistX(ctx context.Context) bool {
|
||||
exist, err := _q.Exist(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return exist
|
||||
}
|
||||
|
||||
// Clone returns a duplicate of the UsageCleanupTaskQuery builder, including all associated steps. It can be
|
||||
// used to prepare common query builders and use them differently after the clone is made.
|
||||
func (_q *UsageCleanupTaskQuery) Clone() *UsageCleanupTaskQuery {
|
||||
if _q == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageCleanupTaskQuery{
|
||||
config: _q.config,
|
||||
ctx: _q.ctx.Clone(),
|
||||
order: append([]usagecleanuptask.OrderOption{}, _q.order...),
|
||||
inters: append([]Interceptor{}, _q.inters...),
|
||||
predicates: append([]predicate.UsageCleanupTask{}, _q.predicates...),
|
||||
// clone intermediate query.
|
||||
sql: _q.sql.Clone(),
|
||||
path: _q.path,
|
||||
}
|
||||
}
|
||||
|
||||
// GroupBy is used to group vertices by one or more fields/columns.
|
||||
// It is often used with aggregate functions, like: count, max, mean, min, sum.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// Count int `json:"count,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.UsageCleanupTask.Query().
|
||||
// GroupBy(usagecleanuptask.FieldCreatedAt).
|
||||
// Aggregate(ent.Count()).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *UsageCleanupTaskQuery) GroupBy(field string, fields ...string) *UsageCleanupTaskGroupBy {
|
||||
_q.ctx.Fields = append([]string{field}, fields...)
|
||||
grbuild := &UsageCleanupTaskGroupBy{build: _q}
|
||||
grbuild.flds = &_q.ctx.Fields
|
||||
grbuild.label = usagecleanuptask.Label
|
||||
grbuild.scan = grbuild.Scan
|
||||
return grbuild
|
||||
}
|
||||
|
||||
// Select allows the selection one or more fields/columns for the given query,
|
||||
// instead of selecting all fields in the entity.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.UsageCleanupTask.Query().
|
||||
// Select(usagecleanuptask.FieldCreatedAt).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *UsageCleanupTaskQuery) Select(fields ...string) *UsageCleanupTaskSelect {
|
||||
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
|
||||
sbuild := &UsageCleanupTaskSelect{UsageCleanupTaskQuery: _q}
|
||||
sbuild.label = usagecleanuptask.Label
|
||||
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
|
||||
return sbuild
|
||||
}
|
||||
|
||||
// Aggregate returns a UsageCleanupTaskSelect configured with the given aggregations.
|
||||
func (_q *UsageCleanupTaskQuery) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskSelect {
|
||||
return _q.Select().Aggregate(fns...)
|
||||
}
|
||||
|
||||
func (_q *UsageCleanupTaskQuery) prepareQuery(ctx context.Context) error {
|
||||
for _, inter := range _q.inters {
|
||||
if inter == nil {
|
||||
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
|
||||
}
|
||||
if trv, ok := inter.(Traverser); ok {
|
||||
if err := trv.Traverse(ctx, _q); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, f := range _q.ctx.Fields {
|
||||
if !usagecleanuptask.ValidColumn(f) {
|
||||
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
}
|
||||
if _q.path != nil {
|
||||
prev, err := _q.path(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_q.sql = prev
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_q *UsageCleanupTaskQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UsageCleanupTask, error) {
|
||||
var (
|
||||
nodes = []*UsageCleanupTask{}
|
||||
_spec = _q.querySpec()
|
||||
)
|
||||
_spec.ScanValues = func(columns []string) ([]any, error) {
|
||||
return (*UsageCleanupTask).scanValues(nil, columns)
|
||||
}
|
||||
_spec.Assign = func(columns []string, values []any) error {
|
||||
node := &UsageCleanupTask{config: _q.config}
|
||||
nodes = append(nodes, node)
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nodes, nil
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (_q *UsageCleanupTaskQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
}
|
||||
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
|
||||
}
|
||||
|
||||
func (_q *UsageCleanupTaskQuery) querySpec() *sqlgraph.QuerySpec {
|
||||
_spec := sqlgraph.NewQuerySpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
|
||||
_spec.From = _q.sql
|
||||
if unique := _q.ctx.Unique; unique != nil {
|
||||
_spec.Unique = *unique
|
||||
} else if _q.path != nil {
|
||||
_spec.Unique = true
|
||||
}
|
||||
if fields := _q.ctx.Fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, usagecleanuptask.FieldID)
|
||||
for i := range fields {
|
||||
if fields[i] != usagecleanuptask.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
if ps := _q.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
_spec.Limit = *limit
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
_spec.Offset = *offset
|
||||
}
|
||||
if ps := _q.order; len(ps) > 0 {
|
||||
_spec.Order = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
return _spec
|
||||
}
|
||||
|
||||
func (_q *UsageCleanupTaskQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
builder := sql.Dialect(_q.driver.Dialect())
|
||||
t1 := builder.Table(usagecleanuptask.Table)
|
||||
columns := _q.ctx.Fields
|
||||
if len(columns) == 0 {
|
||||
columns = usagecleanuptask.Columns
|
||||
}
|
||||
selector := builder.Select(t1.Columns(columns...)...).From(t1)
|
||||
if _q.sql != nil {
|
||||
selector = _q.sql
|
||||
selector.Select(selector.Columns(columns...)...)
|
||||
}
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
for _, p := range _q.order {
|
||||
p(selector)
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
// limit is mandatory for offset clause. We start
|
||||
// with default value, and override it below if needed.
|
||||
selector.Offset(*offset).Limit(math.MaxInt32)
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
selector.Limit(*limit)
|
||||
}
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *UsageCleanupTaskQuery) ForUpdate(opts ...sql.LockOption) *UsageCleanupTaskQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *UsageCleanupTaskQuery) ForShare(opts ...sql.LockOption) *UsageCleanupTaskQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// UsageCleanupTaskGroupBy is the group-by builder for UsageCleanupTask entities.
|
||||
type UsageCleanupTaskGroupBy struct {
|
||||
selector
|
||||
build *UsageCleanupTaskQuery
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the group-by query.
|
||||
func (_g *UsageCleanupTaskGroupBy) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskGroupBy {
|
||||
_g.fns = append(_g.fns, fns...)
|
||||
return _g
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_g *UsageCleanupTaskGroupBy) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
|
||||
if err := _g.build.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*UsageCleanupTaskQuery, *UsageCleanupTaskGroupBy](ctx, _g.build, _g, _g.build.inters, v)
|
||||
}
|
||||
|
||||
func (_g *UsageCleanupTaskGroupBy) sqlScan(ctx context.Context, root *UsageCleanupTaskQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx).Select()
|
||||
aggregation := make([]string, 0, len(_g.fns))
|
||||
for _, fn := range _g.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
if len(selector.SelectedColumns()) == 0 {
|
||||
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
|
||||
for _, f := range *_g.flds {
|
||||
columns = append(columns, selector.C(f))
|
||||
}
|
||||
columns = append(columns, aggregation...)
|
||||
selector.Select(columns...)
|
||||
}
|
||||
selector.GroupBy(selector.Columns(*_g.flds...)...)
|
||||
if err := selector.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// UsageCleanupTaskSelect is the builder for selecting fields of UsageCleanupTask entities.
|
||||
type UsageCleanupTaskSelect struct {
|
||||
*UsageCleanupTaskQuery
|
||||
selector
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the selector query.
|
||||
func (_s *UsageCleanupTaskSelect) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskSelect {
|
||||
_s.fns = append(_s.fns, fns...)
|
||||
return _s
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_s *UsageCleanupTaskSelect) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
|
||||
if err := _s.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*UsageCleanupTaskQuery, *UsageCleanupTaskSelect](ctx, _s.UsageCleanupTaskQuery, _s, _s.inters, v)
|
||||
}
|
||||
|
||||
func (_s *UsageCleanupTaskSelect) sqlScan(ctx context.Context, root *UsageCleanupTaskQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx)
|
||||
aggregation := make([]string, 0, len(_s.fns))
|
||||
for _, fn := range _s.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
switch n := len(*_s.selector.flds); {
|
||||
case n == 0 && len(aggregation) > 0:
|
||||
selector.Select(aggregation...)
|
||||
case n != 0 && len(aggregation) > 0:
|
||||
selector.AppendSelect(aggregation...)
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
702
backend/ent/usagecleanuptask_update.go
Normal file
702
backend/ent/usagecleanuptask_update.go
Normal file
@@ -0,0 +1,702 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/dialect/sql/sqljson"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
|
||||
)
|
||||
|
||||
// UsageCleanupTaskUpdate is the builder for updating UsageCleanupTask entities.
|
||||
type UsageCleanupTaskUpdate struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *UsageCleanupTaskMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UsageCleanupTaskUpdate builder.
|
||||
func (_u *UsageCleanupTaskUpdate) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpdatedAt sets the "updated_at" field.
|
||||
func (_u *UsageCleanupTaskUpdate) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetStatus sets the "status" field.
|
||||
func (_u *UsageCleanupTaskUpdate) SetStatus(v string) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.SetStatus(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableStatus sets the "status" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdate) SetNillableStatus(v *string) *UsageCleanupTaskUpdate {
|
||||
if v != nil {
|
||||
_u.SetStatus(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetFilters sets the "filters" field.
|
||||
func (_u *UsageCleanupTaskUpdate) SetFilters(v json.RawMessage) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.SetFilters(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendFilters appends value to the "filters" field.
|
||||
func (_u *UsageCleanupTaskUpdate) AppendFilters(v json.RawMessage) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.AppendFilters(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCreatedBy sets the "created_by" field.
|
||||
func (_u *UsageCleanupTaskUpdate) SetCreatedBy(v int64) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.ResetCreatedBy()
|
||||
_u.mutation.SetCreatedBy(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdate) SetNillableCreatedBy(v *int64) *UsageCleanupTaskUpdate {
|
||||
if v != nil {
|
||||
_u.SetCreatedBy(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddCreatedBy adds value to the "created_by" field.
|
||||
func (_u *UsageCleanupTaskUpdate) AddCreatedBy(v int64) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.AddCreatedBy(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDeletedRows sets the "deleted_rows" field.
|
||||
func (_u *UsageCleanupTaskUpdate) SetDeletedRows(v int64) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.ResetDeletedRows()
|
||||
_u.mutation.SetDeletedRows(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDeletedRows sets the "deleted_rows" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdate) SetNillableDeletedRows(v *int64) *UsageCleanupTaskUpdate {
|
||||
if v != nil {
|
||||
_u.SetDeletedRows(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddDeletedRows adds value to the "deleted_rows" field.
|
||||
func (_u *UsageCleanupTaskUpdate) AddDeletedRows(v int64) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.AddDeletedRows(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetErrorMessage sets the "error_message" field.
|
||||
func (_u *UsageCleanupTaskUpdate) SetErrorMessage(v string) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.SetErrorMessage(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableErrorMessage sets the "error_message" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdate) SetNillableErrorMessage(v *string) *UsageCleanupTaskUpdate {
|
||||
if v != nil {
|
||||
_u.SetErrorMessage(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearErrorMessage clears the value of the "error_message" field.
|
||||
func (_u *UsageCleanupTaskUpdate) ClearErrorMessage() *UsageCleanupTaskUpdate {
|
||||
_u.mutation.ClearErrorMessage()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCanceledBy sets the "canceled_by" field.
|
||||
func (_u *UsageCleanupTaskUpdate) SetCanceledBy(v int64) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.ResetCanceledBy()
|
||||
_u.mutation.SetCanceledBy(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableCanceledBy sets the "canceled_by" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdate) SetNillableCanceledBy(v *int64) *UsageCleanupTaskUpdate {
|
||||
if v != nil {
|
||||
_u.SetCanceledBy(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddCanceledBy adds value to the "canceled_by" field.
|
||||
func (_u *UsageCleanupTaskUpdate) AddCanceledBy(v int64) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.AddCanceledBy(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearCanceledBy clears the value of the "canceled_by" field.
|
||||
func (_u *UsageCleanupTaskUpdate) ClearCanceledBy() *UsageCleanupTaskUpdate {
|
||||
_u.mutation.ClearCanceledBy()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCanceledAt sets the "canceled_at" field.
|
||||
func (_u *UsageCleanupTaskUpdate) SetCanceledAt(v time.Time) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.SetCanceledAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableCanceledAt sets the "canceled_at" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdate) SetNillableCanceledAt(v *time.Time) *UsageCleanupTaskUpdate {
|
||||
if v != nil {
|
||||
_u.SetCanceledAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearCanceledAt clears the value of the "canceled_at" field.
|
||||
func (_u *UsageCleanupTaskUpdate) ClearCanceledAt() *UsageCleanupTaskUpdate {
|
||||
_u.mutation.ClearCanceledAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetStartedAt sets the "started_at" field.
|
||||
func (_u *UsageCleanupTaskUpdate) SetStartedAt(v time.Time) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.SetStartedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableStartedAt sets the "started_at" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdate) SetNillableStartedAt(v *time.Time) *UsageCleanupTaskUpdate {
|
||||
if v != nil {
|
||||
_u.SetStartedAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearStartedAt clears the value of the "started_at" field.
|
||||
func (_u *UsageCleanupTaskUpdate) ClearStartedAt() *UsageCleanupTaskUpdate {
|
||||
_u.mutation.ClearStartedAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetFinishedAt sets the "finished_at" field.
|
||||
func (_u *UsageCleanupTaskUpdate) SetFinishedAt(v time.Time) *UsageCleanupTaskUpdate {
|
||||
_u.mutation.SetFinishedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableFinishedAt sets the "finished_at" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdate) SetNillableFinishedAt(v *time.Time) *UsageCleanupTaskUpdate {
|
||||
if v != nil {
|
||||
_u.SetFinishedAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearFinishedAt clears the value of the "finished_at" field.
|
||||
func (_u *UsageCleanupTaskUpdate) ClearFinishedAt() *UsageCleanupTaskUpdate {
|
||||
_u.mutation.ClearFinishedAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// Mutation returns the UsageCleanupTaskMutation object of the builder.
|
||||
func (_u *UsageCleanupTaskUpdate) Mutation() *UsageCleanupTaskMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||
func (_u *UsageCleanupTaskUpdate) Save(ctx context.Context) (int, error) {
|
||||
_u.defaults()
|
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *UsageCleanupTaskUpdate) SaveX(ctx context.Context) int {
|
||||
affected, err := _u.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return affected
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (_u *UsageCleanupTaskUpdate) Exec(ctx context.Context) error {
|
||||
_, err := _u.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *UsageCleanupTaskUpdate) ExecX(ctx context.Context) {
|
||||
if err := _u.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (_u *UsageCleanupTaskUpdate) defaults() {
|
||||
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||
v := usagecleanuptask.UpdateDefaultUpdatedAt()
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
}
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_u *UsageCleanupTaskUpdate) check() error {
|
||||
if v, ok := _u.mutation.Status(); ok {
|
||||
if err := usagecleanuptask.StatusValidator(v); err != nil {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UsageCleanupTask.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_u *UsageCleanupTaskUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if err := _u.check(); err != nil {
|
||||
return _node, err
|
||||
}
|
||||
_spec := sqlgraph.NewUpdateSpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
|
||||
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldUpdatedAt, field.TypeTime, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Filters(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldFilters, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AppendedFilters(); ok {
|
||||
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, usagecleanuptask.FieldFilters, value)
|
||||
})
|
||||
}
|
||||
if value, ok := _u.mutation.CreatedBy(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedCreatedBy(); ok {
|
||||
_spec.AddField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DeletedRows(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedDeletedRows(); ok {
|
||||
_spec.AddField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ErrorMessage(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldErrorMessage, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ErrorMessageCleared() {
|
||||
_spec.ClearField(usagecleanuptask.FieldErrorMessage, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.CanceledBy(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedCanceledBy(); ok {
|
||||
_spec.AddField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
|
||||
}
|
||||
if _u.mutation.CanceledByCleared() {
|
||||
_spec.ClearField(usagecleanuptask.FieldCanceledBy, field.TypeInt64)
|
||||
}
|
||||
if value, ok := _u.mutation.CanceledAt(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldCanceledAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.CanceledAtCleared() {
|
||||
_spec.ClearField(usagecleanuptask.FieldCanceledAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.StartedAt(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldStartedAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.StartedAtCleared() {
|
||||
_spec.ClearField(usagecleanuptask.FieldStartedAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.FinishedAt(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldFinishedAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.FinishedAtCleared() {
|
||||
_spec.ClearField(usagecleanuptask.FieldFinishedAt, field.TypeTime)
|
||||
}
|
||||
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{usagecleanuptask.Label}
|
||||
} else if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
_u.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
|
||||
// UsageCleanupTaskUpdateOne is the builder for updating a single UsageCleanupTask entity.
|
||||
type UsageCleanupTaskUpdateOne struct {
|
||||
config
|
||||
fields []string
|
||||
hooks []Hook
|
||||
mutation *UsageCleanupTaskMutation
|
||||
}
|
||||
|
||||
// SetUpdatedAt sets the "updated_at" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetStatus sets the "status" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetStatus(v string) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.SetStatus(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableStatus sets the "status" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetNillableStatus(v *string) *UsageCleanupTaskUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetStatus(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetFilters sets the "filters" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetFilters(v json.RawMessage) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.SetFilters(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AppendFilters appends value to the "filters" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) AppendFilters(v json.RawMessage) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.AppendFilters(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCreatedBy sets the "created_by" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetCreatedBy(v int64) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.ResetCreatedBy()
|
||||
_u.mutation.SetCreatedBy(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetNillableCreatedBy(v *int64) *UsageCleanupTaskUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetCreatedBy(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddCreatedBy adds value to the "created_by" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) AddCreatedBy(v int64) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.AddCreatedBy(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDeletedRows sets the "deleted_rows" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetDeletedRows(v int64) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.ResetDeletedRows()
|
||||
_u.mutation.SetDeletedRows(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDeletedRows sets the "deleted_rows" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetNillableDeletedRows(v *int64) *UsageCleanupTaskUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDeletedRows(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddDeletedRows adds value to the "deleted_rows" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) AddDeletedRows(v int64) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.AddDeletedRows(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetErrorMessage sets the "error_message" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetErrorMessage(v string) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.SetErrorMessage(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableErrorMessage sets the "error_message" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetNillableErrorMessage(v *string) *UsageCleanupTaskUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetErrorMessage(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearErrorMessage clears the value of the "error_message" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) ClearErrorMessage() *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.ClearErrorMessage()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCanceledBy sets the "canceled_by" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetCanceledBy(v int64) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.ResetCanceledBy()
|
||||
_u.mutation.SetCanceledBy(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableCanceledBy sets the "canceled_by" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetNillableCanceledBy(v *int64) *UsageCleanupTaskUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetCanceledBy(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddCanceledBy adds value to the "canceled_by" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) AddCanceledBy(v int64) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.AddCanceledBy(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearCanceledBy clears the value of the "canceled_by" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) ClearCanceledBy() *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.ClearCanceledBy()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCanceledAt sets the "canceled_at" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetCanceledAt(v time.Time) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.SetCanceledAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableCanceledAt sets the "canceled_at" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetNillableCanceledAt(v *time.Time) *UsageCleanupTaskUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetCanceledAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearCanceledAt clears the value of the "canceled_at" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) ClearCanceledAt() *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.ClearCanceledAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetStartedAt sets the "started_at" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetStartedAt(v time.Time) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.SetStartedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableStartedAt sets the "started_at" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetNillableStartedAt(v *time.Time) *UsageCleanupTaskUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetStartedAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearStartedAt clears the value of the "started_at" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) ClearStartedAt() *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.ClearStartedAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetFinishedAt sets the "finished_at" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetFinishedAt(v time.Time) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.SetFinishedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableFinishedAt sets the "finished_at" field if the given value is not nil.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SetNillableFinishedAt(v *time.Time) *UsageCleanupTaskUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetFinishedAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearFinishedAt clears the value of the "finished_at" field.
|
||||
func (_u *UsageCleanupTaskUpdateOne) ClearFinishedAt() *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.ClearFinishedAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// Mutation returns the UsageCleanupTaskMutation object of the builder.
|
||||
func (_u *UsageCleanupTaskUpdateOne) Mutation() *UsageCleanupTaskMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UsageCleanupTaskUpdate builder.
|
||||
func (_u *UsageCleanupTaskUpdateOne) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskUpdateOne {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// Select allows selecting one or more fields (columns) of the returned entity.
|
||||
// The default is selecting all fields defined in the entity schema.
|
||||
func (_u *UsageCleanupTaskUpdateOne) Select(field string, fields ...string) *UsageCleanupTaskUpdateOne {
|
||||
_u.fields = append([]string{field}, fields...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// Save executes the query and returns the updated UsageCleanupTask entity.
|
||||
func (_u *UsageCleanupTaskUpdateOne) Save(ctx context.Context) (*UsageCleanupTask, error) {
|
||||
_u.defaults()
|
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *UsageCleanupTaskUpdateOne) SaveX(ctx context.Context) *UsageCleanupTask {
|
||||
node, err := _u.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// Exec executes the query on the entity.
|
||||
func (_u *UsageCleanupTaskUpdateOne) Exec(ctx context.Context) error {
|
||||
_, err := _u.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *UsageCleanupTaskUpdateOne) ExecX(ctx context.Context) {
|
||||
if err := _u.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (_u *UsageCleanupTaskUpdateOne) defaults() {
|
||||
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||
v := usagecleanuptask.UpdateDefaultUpdatedAt()
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
}
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_u *UsageCleanupTaskUpdateOne) check() error {
|
||||
if v, ok := _u.mutation.Status(); ok {
|
||||
if err := usagecleanuptask.StatusValidator(v); err != nil {
|
||||
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UsageCleanupTask.status": %w`, err)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_u *UsageCleanupTaskUpdateOne) sqlSave(ctx context.Context) (_node *UsageCleanupTask, err error) {
|
||||
if err := _u.check(); err != nil {
|
||||
return _node, err
|
||||
}
|
||||
_spec := sqlgraph.NewUpdateSpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
|
||||
id, ok := _u.mutation.ID()
|
||||
if !ok {
|
||||
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UsageCleanupTask.id" for update`)}
|
||||
}
|
||||
_spec.Node.ID.Value = id
|
||||
if fields := _u.fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, usagecleanuptask.FieldID)
|
||||
for _, f := range fields {
|
||||
if !usagecleanuptask.ValidColumn(f) {
|
||||
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
if f != usagecleanuptask.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldUpdatedAt, field.TypeTime, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Status(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldStatus, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Filters(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldFilters, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AppendedFilters(); ok {
|
||||
_spec.AddModifier(func(u *sql.UpdateBuilder) {
|
||||
sqljson.Append(u, usagecleanuptask.FieldFilters, value)
|
||||
})
|
||||
}
|
||||
if value, ok := _u.mutation.CreatedBy(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedCreatedBy(); ok {
|
||||
_spec.AddField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DeletedRows(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedDeletedRows(); ok {
|
||||
_spec.AddField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ErrorMessage(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldErrorMessage, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ErrorMessageCleared() {
|
||||
_spec.ClearField(usagecleanuptask.FieldErrorMessage, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.CanceledBy(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedCanceledBy(); ok {
|
||||
_spec.AddField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
|
||||
}
|
||||
if _u.mutation.CanceledByCleared() {
|
||||
_spec.ClearField(usagecleanuptask.FieldCanceledBy, field.TypeInt64)
|
||||
}
|
||||
if value, ok := _u.mutation.CanceledAt(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldCanceledAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.CanceledAtCleared() {
|
||||
_spec.ClearField(usagecleanuptask.FieldCanceledAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.StartedAt(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldStartedAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.StartedAtCleared() {
|
||||
_spec.ClearField(usagecleanuptask.FieldStartedAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.FinishedAt(); ok {
|
||||
_spec.SetField(usagecleanuptask.FieldFinishedAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.FinishedAtCleared() {
|
||||
_spec.ClearField(usagecleanuptask.FieldFinishedAt, field.TypeTime)
|
||||
}
|
||||
_node = &UsageCleanupTask{config: _u.config}
|
||||
_spec.Assign = _node.assignValues
|
||||
_spec.ScanValues = _node.scanValues
|
||||
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{usagecleanuptask.Label}
|
||||
} else if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_u.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
@@ -62,6 +62,8 @@ type UsageLog struct {
|
||||
ActualCost float64 `json:"actual_cost,omitempty"`
|
||||
// RateMultiplier holds the value of the "rate_multiplier" field.
|
||||
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
|
||||
// AccountRateMultiplier holds the value of the "account_rate_multiplier" field.
|
||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier,omitempty"`
|
||||
// BillingType holds the value of the "billing_type" field.
|
||||
BillingType int8 `json:"billing_type,omitempty"`
|
||||
// Stream holds the value of the "stream" field.
|
||||
@@ -165,7 +167,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case usagelog.FieldStream:
|
||||
values[i] = new(sql.NullBool)
|
||||
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier:
|
||||
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||
values[i] = new(sql.NullInt64)
|
||||
@@ -316,6 +318,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.RateMultiplier = value.Float64
|
||||
}
|
||||
case usagelog.FieldAccountRateMultiplier:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field account_rate_multiplier", values[i])
|
||||
} else if value.Valid {
|
||||
_m.AccountRateMultiplier = new(float64)
|
||||
*_m.AccountRateMultiplier = value.Float64
|
||||
}
|
||||
case usagelog.FieldBillingType:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field billing_type", values[i])
|
||||
@@ -500,6 +509,11 @@ func (_m *UsageLog) String() string {
|
||||
builder.WriteString("rate_multiplier=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.AccountRateMultiplier; v != nil {
|
||||
builder.WriteString("account_rate_multiplier=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("billing_type=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.BillingType))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@@ -54,6 +54,8 @@ const (
|
||||
FieldActualCost = "actual_cost"
|
||||
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
|
||||
FieldRateMultiplier = "rate_multiplier"
|
||||
// FieldAccountRateMultiplier holds the string denoting the account_rate_multiplier field in the database.
|
||||
FieldAccountRateMultiplier = "account_rate_multiplier"
|
||||
// FieldBillingType holds the string denoting the billing_type field in the database.
|
||||
FieldBillingType = "billing_type"
|
||||
// FieldStream holds the string denoting the stream field in the database.
|
||||
@@ -144,6 +146,7 @@ var Columns = []string{
|
||||
FieldTotalCost,
|
||||
FieldActualCost,
|
||||
FieldRateMultiplier,
|
||||
FieldAccountRateMultiplier,
|
||||
FieldBillingType,
|
||||
FieldStream,
|
||||
FieldDurationMs,
|
||||
@@ -320,6 +323,11 @@ func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAccountRateMultiplier orders the results by the account_rate_multiplier field.
|
||||
func ByAccountRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldAccountRateMultiplier, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByBillingType orders the results by the billing_type field.
|
||||
func ByBillingType(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldBillingType, opts...).ToFunc()
|
||||
|
||||
@@ -155,6 +155,11 @@ func RateMultiplier(v float64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v))
|
||||
}
|
||||
|
||||
// AccountRateMultiplier applies equality check predicate on the "account_rate_multiplier" field. It's identical to AccountRateMultiplierEQ.
|
||||
func AccountRateMultiplier(v float64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
|
||||
}
|
||||
|
||||
// BillingType applies equality check predicate on the "billing_type" field. It's identical to BillingTypeEQ.
|
||||
func BillingType(v int8) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
|
||||
@@ -970,6 +975,56 @@ func RateMultiplierLTE(v float64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldRateMultiplier, v))
|
||||
}
|
||||
|
||||
// AccountRateMultiplierEQ applies the EQ predicate on the "account_rate_multiplier" field.
|
||||
func AccountRateMultiplierEQ(v float64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
|
||||
}
|
||||
|
||||
// AccountRateMultiplierNEQ applies the NEQ predicate on the "account_rate_multiplier" field.
|
||||
func AccountRateMultiplierNEQ(v float64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldAccountRateMultiplier, v))
|
||||
}
|
||||
|
||||
// AccountRateMultiplierIn applies the In predicate on the "account_rate_multiplier" field.
|
||||
func AccountRateMultiplierIn(vs ...float64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldAccountRateMultiplier, vs...))
|
||||
}
|
||||
|
||||
// AccountRateMultiplierNotIn applies the NotIn predicate on the "account_rate_multiplier" field.
|
||||
func AccountRateMultiplierNotIn(vs ...float64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldAccountRateMultiplier, vs...))
|
||||
}
|
||||
|
||||
// AccountRateMultiplierGT applies the GT predicate on the "account_rate_multiplier" field.
|
||||
func AccountRateMultiplierGT(v float64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldAccountRateMultiplier, v))
|
||||
}
|
||||
|
||||
// AccountRateMultiplierGTE applies the GTE predicate on the "account_rate_multiplier" field.
|
||||
func AccountRateMultiplierGTE(v float64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldAccountRateMultiplier, v))
|
||||
}
|
||||
|
||||
// AccountRateMultiplierLT applies the LT predicate on the "account_rate_multiplier" field.
|
||||
func AccountRateMultiplierLT(v float64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldAccountRateMultiplier, v))
|
||||
}
|
||||
|
||||
// AccountRateMultiplierLTE applies the LTE predicate on the "account_rate_multiplier" field.
|
||||
func AccountRateMultiplierLTE(v float64) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldAccountRateMultiplier, v))
|
||||
}
|
||||
|
||||
// AccountRateMultiplierIsNil applies the IsNil predicate on the "account_rate_multiplier" field.
|
||||
func AccountRateMultiplierIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldAccountRateMultiplier))
|
||||
}
|
||||
|
||||
// AccountRateMultiplierNotNil applies the NotNil predicate on the "account_rate_multiplier" field.
|
||||
func AccountRateMultiplierNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldAccountRateMultiplier))
|
||||
}
|
||||
|
||||
// BillingTypeEQ applies the EQ predicate on the "billing_type" field.
|
||||
func BillingTypeEQ(v int8) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
|
||||
|
||||
@@ -267,6 +267,20 @@ func (_c *UsageLogCreate) SetNillableRateMultiplier(v *float64) *UsageLogCreate
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||
func (_c *UsageLogCreate) SetAccountRateMultiplier(v float64) *UsageLogCreate {
|
||||
_c.mutation.SetAccountRateMultiplier(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableAccountRateMultiplier(v *float64) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetAccountRateMultiplier(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetBillingType sets the "billing_type" field.
|
||||
func (_c *UsageLogCreate) SetBillingType(v int8) *UsageLogCreate {
|
||||
_c.mutation.SetBillingType(v)
|
||||
@@ -712,6 +726,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
|
||||
_node.RateMultiplier = value
|
||||
}
|
||||
if value, ok := _c.mutation.AccountRateMultiplier(); ok {
|
||||
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
|
||||
_node.AccountRateMultiplier = &value
|
||||
}
|
||||
if value, ok := _c.mutation.BillingType(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
|
||||
_node.BillingType = value
|
||||
@@ -1215,6 +1233,30 @@ func (u *UsageLogUpsert) AddRateMultiplier(v float64) *UsageLogUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||
func (u *UsageLogUpsert) SetAccountRateMultiplier(v float64) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldAccountRateMultiplier, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateAccountRateMultiplier() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldAccountRateMultiplier)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
|
||||
func (u *UsageLogUpsert) AddAccountRateMultiplier(v float64) *UsageLogUpsert {
|
||||
u.Add(usagelog.FieldAccountRateMultiplier, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
|
||||
func (u *UsageLogUpsert) ClearAccountRateMultiplier() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldAccountRateMultiplier)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetBillingType sets the "billing_type" field.
|
||||
func (u *UsageLogUpsert) SetBillingType(v int8) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldBillingType, v)
|
||||
@@ -1795,6 +1837,34 @@ func (u *UsageLogUpsertOne) UpdateRateMultiplier() *UsageLogUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||
func (u *UsageLogUpsertOne) SetAccountRateMultiplier(v float64) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetAccountRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
|
||||
func (u *UsageLogUpsertOne) AddAccountRateMultiplier(v float64) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.AddAccountRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateAccountRateMultiplier() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateAccountRateMultiplier()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
|
||||
func (u *UsageLogUpsertOne) ClearAccountRateMultiplier() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearAccountRateMultiplier()
|
||||
})
|
||||
}
|
||||
|
||||
// SetBillingType sets the "billing_type" field.
|
||||
func (u *UsageLogUpsertOne) SetBillingType(v int8) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
@@ -2566,6 +2636,34 @@ func (u *UsageLogUpsertBulk) UpdateRateMultiplier() *UsageLogUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||
func (u *UsageLogUpsertBulk) SetAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetAccountRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
|
||||
func (u *UsageLogUpsertBulk) AddAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.AddAccountRateMultiplier(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateAccountRateMultiplier() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateAccountRateMultiplier()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
|
||||
func (u *UsageLogUpsertBulk) ClearAccountRateMultiplier() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearAccountRateMultiplier()
|
||||
})
|
||||
}
|
||||
|
||||
// SetBillingType sets the "billing_type" field.
|
||||
func (u *UsageLogUpsertBulk) SetBillingType(v int8) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
|
||||
@@ -415,6 +415,33 @@ func (_u *UsageLogUpdate) AddRateMultiplier(v float64) *UsageLogUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||
func (_u *UsageLogUpdate) SetAccountRateMultiplier(v float64) *UsageLogUpdate {
|
||||
_u.mutation.ResetAccountRateMultiplier()
|
||||
_u.mutation.SetAccountRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetAccountRateMultiplier(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
|
||||
func (_u *UsageLogUpdate) AddAccountRateMultiplier(v float64) *UsageLogUpdate {
|
||||
_u.mutation.AddAccountRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
|
||||
func (_u *UsageLogUpdate) ClearAccountRateMultiplier() *UsageLogUpdate {
|
||||
_u.mutation.ClearAccountRateMultiplier()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBillingType sets the "billing_type" field.
|
||||
func (_u *UsageLogUpdate) SetBillingType(v int8) *UsageLogUpdate {
|
||||
_u.mutation.ResetBillingType()
|
||||
@@ -807,6 +834,15 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
|
||||
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AccountRateMultiplier(); ok {
|
||||
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
|
||||
_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.AccountRateMultiplierCleared() {
|
||||
_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.BillingType(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
|
||||
}
|
||||
@@ -1406,6 +1442,33 @@ func (_u *UsageLogUpdateOne) AddRateMultiplier(v float64) *UsageLogUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
|
||||
func (_u *UsageLogUpdateOne) SetAccountRateMultiplier(v float64) *UsageLogUpdateOne {
|
||||
_u.mutation.ResetAccountRateMultiplier()
|
||||
_u.mutation.SetAccountRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetAccountRateMultiplier(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
|
||||
func (_u *UsageLogUpdateOne) AddAccountRateMultiplier(v float64) *UsageLogUpdateOne {
|
||||
_u.mutation.AddAccountRateMultiplier(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
|
||||
func (_u *UsageLogUpdateOne) ClearAccountRateMultiplier() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearAccountRateMultiplier()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetBillingType sets the "billing_type" field.
|
||||
func (_u *UsageLogUpdateOne) SetBillingType(v int8) *UsageLogUpdateOne {
|
||||
_u.mutation.ResetBillingType()
|
||||
@@ -1828,6 +1891,15 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
||||
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
|
||||
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AccountRateMultiplier(); ok {
|
||||
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
|
||||
_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.AccountRateMultiplierCleared() {
|
||||
_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.BillingType(); ok {
|
||||
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,12 @@ type User struct {
|
||||
Username string `json:"username,omitempty"`
|
||||
// Notes holds the value of the "notes" field.
|
||||
Notes string `json:"notes,omitempty"`
|
||||
// TotpSecretEncrypted holds the value of the "totp_secret_encrypted" field.
|
||||
TotpSecretEncrypted *string `json:"totp_secret_encrypted,omitempty"`
|
||||
// TotpEnabled holds the value of the "totp_enabled" field.
|
||||
TotpEnabled bool `json:"totp_enabled,omitempty"`
|
||||
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
|
||||
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the UserQuery when eager-loading is set.
|
||||
Edges UserEdges `json:"edges"`
|
||||
@@ -156,13 +162,15 @@ func (*User) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case user.FieldTotpEnabled:
|
||||
values[i] = new(sql.NullBool)
|
||||
case user.FieldBalance:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case user.FieldID, user.FieldConcurrency:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes:
|
||||
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
|
||||
values[i] = new(sql.NullString)
|
||||
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
|
||||
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
@@ -252,6 +260,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Notes = value.String
|
||||
}
|
||||
case user.FieldTotpSecretEncrypted:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field totp_secret_encrypted", values[i])
|
||||
} else if value.Valid {
|
||||
_m.TotpSecretEncrypted = new(string)
|
||||
*_m.TotpSecretEncrypted = value.String
|
||||
}
|
||||
case user.FieldTotpEnabled:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field totp_enabled", values[i])
|
||||
} else if value.Valid {
|
||||
_m.TotpEnabled = value.Bool
|
||||
}
|
||||
case user.FieldTotpEnabledAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field totp_enabled_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.TotpEnabledAt = new(time.Time)
|
||||
*_m.TotpEnabledAt = value.Time
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -367,6 +395,19 @@ func (_m *User) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("notes=")
|
||||
builder.WriteString(_m.Notes)
|
||||
builder.WriteString(", ")
|
||||
if v := _m.TotpSecretEncrypted; v != nil {
|
||||
builder.WriteString("totp_secret_encrypted=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("totp_enabled=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.TotpEnabled))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.TotpEnabledAt; v != nil {
|
||||
builder.WriteString("totp_enabled_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -37,6 +37,12 @@ const (
|
||||
FieldUsername = "username"
|
||||
// FieldNotes holds the string denoting the notes field in the database.
|
||||
FieldNotes = "notes"
|
||||
// FieldTotpSecretEncrypted holds the string denoting the totp_secret_encrypted field in the database.
|
||||
FieldTotpSecretEncrypted = "totp_secret_encrypted"
|
||||
// FieldTotpEnabled holds the string denoting the totp_enabled field in the database.
|
||||
FieldTotpEnabled = "totp_enabled"
|
||||
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
|
||||
FieldTotpEnabledAt = "totp_enabled_at"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -134,6 +140,9 @@ var Columns = []string{
|
||||
FieldStatus,
|
||||
FieldUsername,
|
||||
FieldNotes,
|
||||
FieldTotpSecretEncrypted,
|
||||
FieldTotpEnabled,
|
||||
FieldTotpEnabledAt,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -188,6 +197,8 @@ var (
|
||||
UsernameValidator func(string) error
|
||||
// DefaultNotes holds the default value on creation for the "notes" field.
|
||||
DefaultNotes string
|
||||
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
|
||||
DefaultTotpEnabled bool
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the User queries.
|
||||
@@ -253,6 +264,21 @@ func ByNotes(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldNotes, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByTotpSecretEncrypted orders the results by the totp_secret_encrypted field.
|
||||
func ByTotpSecretEncrypted(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldTotpSecretEncrypted, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByTotpEnabled orders the results by the totp_enabled field.
|
||||
func ByTotpEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldTotpEnabled, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByTotpEnabledAt orders the results by the totp_enabled_at field.
|
||||
func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -110,6 +110,21 @@ func Notes(v string) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldNotes, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncrypted applies equality check predicate on the "totp_secret_encrypted" field. It's identical to TotpSecretEncryptedEQ.
|
||||
func TotpSecretEncrypted(v string) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpEnabled applies equality check predicate on the "totp_enabled" field. It's identical to TotpEnabledEQ.
|
||||
func TotpEnabled(v bool) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAt applies equality check predicate on the "totp_enabled_at" field. It's identical to TotpEnabledAtEQ.
|
||||
func TotpEnabledAt(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -710,6 +725,141 @@ func NotesContainsFold(v string) predicate.User {
|
||||
return predicate.User(sql.FieldContainsFold(FieldNotes, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedEQ applies the EQ predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedEQ(v string) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedNEQ applies the NEQ predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedNEQ(v string) predicate.User {
|
||||
return predicate.User(sql.FieldNEQ(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedIn applies the In predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedIn(vs ...string) predicate.User {
|
||||
return predicate.User(sql.FieldIn(FieldTotpSecretEncrypted, vs...))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedNotIn applies the NotIn predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedNotIn(vs ...string) predicate.User {
|
||||
return predicate.User(sql.FieldNotIn(FieldTotpSecretEncrypted, vs...))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedGT applies the GT predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedGT(v string) predicate.User {
|
||||
return predicate.User(sql.FieldGT(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedGTE applies the GTE predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedGTE(v string) predicate.User {
|
||||
return predicate.User(sql.FieldGTE(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedLT applies the LT predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedLT(v string) predicate.User {
|
||||
return predicate.User(sql.FieldLT(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedLTE applies the LTE predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedLTE(v string) predicate.User {
|
||||
return predicate.User(sql.FieldLTE(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedContains applies the Contains predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedContains(v string) predicate.User {
|
||||
return predicate.User(sql.FieldContains(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedHasPrefix applies the HasPrefix predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedHasPrefix(v string) predicate.User {
|
||||
return predicate.User(sql.FieldHasPrefix(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedHasSuffix applies the HasSuffix predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedHasSuffix(v string) predicate.User {
|
||||
return predicate.User(sql.FieldHasSuffix(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedIsNil applies the IsNil predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedIsNil() predicate.User {
|
||||
return predicate.User(sql.FieldIsNull(FieldTotpSecretEncrypted))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedNotNil applies the NotNil predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedNotNil() predicate.User {
|
||||
return predicate.User(sql.FieldNotNull(FieldTotpSecretEncrypted))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedEqualFold applies the EqualFold predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedEqualFold(v string) predicate.User {
|
||||
return predicate.User(sql.FieldEqualFold(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpSecretEncryptedContainsFold applies the ContainsFold predicate on the "totp_secret_encrypted" field.
|
||||
func TotpSecretEncryptedContainsFold(v string) predicate.User {
|
||||
return predicate.User(sql.FieldContainsFold(FieldTotpSecretEncrypted, v))
|
||||
}
|
||||
|
||||
// TotpEnabledEQ applies the EQ predicate on the "totp_enabled" field.
|
||||
func TotpEnabledEQ(v bool) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
|
||||
}
|
||||
|
||||
// TotpEnabledNEQ applies the NEQ predicate on the "totp_enabled" field.
|
||||
func TotpEnabledNEQ(v bool) predicate.User {
|
||||
return predicate.User(sql.FieldNEQ(FieldTotpEnabled, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtEQ applies the EQ predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtEQ(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtNEQ applies the NEQ predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtNEQ(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldNEQ(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtIn applies the In predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtIn(vs ...time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldIn(FieldTotpEnabledAt, vs...))
|
||||
}
|
||||
|
||||
// TotpEnabledAtNotIn applies the NotIn predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtNotIn(vs ...time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldNotIn(FieldTotpEnabledAt, vs...))
|
||||
}
|
||||
|
||||
// TotpEnabledAtGT applies the GT predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtGT(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldGT(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtGTE applies the GTE predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtGTE(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldGTE(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtLT applies the LT predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtLT(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldLT(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtLTE applies the LTE predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtLTE(v time.Time) predicate.User {
|
||||
return predicate.User(sql.FieldLTE(FieldTotpEnabledAt, v))
|
||||
}
|
||||
|
||||
// TotpEnabledAtIsNil applies the IsNil predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtIsNil() predicate.User {
|
||||
return predicate.User(sql.FieldIsNull(FieldTotpEnabledAt))
|
||||
}
|
||||
|
||||
// TotpEnabledAtNotNil applies the NotNil predicate on the "totp_enabled_at" field.
|
||||
func TotpEnabledAtNotNil() predicate.User {
|
||||
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
|
||||
@@ -167,6 +167,48 @@ func (_c *UserCreate) SetNillableNotes(v *string) *UserCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (_c *UserCreate) SetTotpSecretEncrypted(v string) *UserCreate {
|
||||
_c.mutation.SetTotpSecretEncrypted(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
|
||||
func (_c *UserCreate) SetNillableTotpSecretEncrypted(v *string) *UserCreate {
|
||||
if v != nil {
|
||||
_c.SetTotpSecretEncrypted(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (_c *UserCreate) SetTotpEnabled(v bool) *UserCreate {
|
||||
_c.mutation.SetTotpEnabled(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
|
||||
func (_c *UserCreate) SetNillableTotpEnabled(v *bool) *UserCreate {
|
||||
if v != nil {
|
||||
_c.SetTotpEnabled(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (_c *UserCreate) SetTotpEnabledAt(v time.Time) *UserCreate {
|
||||
_c.mutation.SetTotpEnabledAt(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
|
||||
func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
|
||||
if v != nil {
|
||||
_c.SetTotpEnabledAt(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -362,6 +404,10 @@ func (_c *UserCreate) defaults() error {
|
||||
v := user.DefaultNotes
|
||||
_c.mutation.SetNotes(v)
|
||||
}
|
||||
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
||||
v := user.DefaultTotpEnabled
|
||||
_c.mutation.SetTotpEnabled(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -422,6 +468,9 @@ func (_c *UserCreate) check() error {
|
||||
if _, ok := _c.mutation.Notes(); !ok {
|
||||
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
||||
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -493,6 +542,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
||||
_node.Notes = value
|
||||
}
|
||||
if value, ok := _c.mutation.TotpSecretEncrypted(); ok {
|
||||
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
|
||||
_node.TotpSecretEncrypted = &value
|
||||
}
|
||||
if value, ok := _c.mutation.TotpEnabled(); ok {
|
||||
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
|
||||
_node.TotpEnabled = value
|
||||
}
|
||||
if value, ok := _c.mutation.TotpEnabledAt(); ok {
|
||||
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||
_node.TotpEnabledAt = &value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -815,6 +876,54 @@ func (u *UserUpsert) UpdateNotes() *UserUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (u *UserUpsert) SetTotpSecretEncrypted(v string) *UserUpsert {
|
||||
u.Set(user.FieldTotpSecretEncrypted, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
|
||||
func (u *UserUpsert) UpdateTotpSecretEncrypted() *UserUpsert {
|
||||
u.SetExcluded(user.FieldTotpSecretEncrypted)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||
func (u *UserUpsert) ClearTotpSecretEncrypted() *UserUpsert {
|
||||
u.SetNull(user.FieldTotpSecretEncrypted)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (u *UserUpsert) SetTotpEnabled(v bool) *UserUpsert {
|
||||
u.Set(user.FieldTotpEnabled, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
|
||||
func (u *UserUpsert) UpdateTotpEnabled() *UserUpsert {
|
||||
u.SetExcluded(user.FieldTotpEnabled)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (u *UserUpsert) SetTotpEnabledAt(v time.Time) *UserUpsert {
|
||||
u.Set(user.FieldTotpEnabledAt, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
|
||||
func (u *UserUpsert) UpdateTotpEnabledAt() *UserUpsert {
|
||||
u.SetExcluded(user.FieldTotpEnabledAt)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||
func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
|
||||
u.SetNull(user.FieldTotpEnabledAt)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -1021,6 +1130,62 @@ func (u *UserUpsertOne) UpdateNotes() *UserUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (u *UserUpsertOne) SetTotpSecretEncrypted(v string) *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetTotpSecretEncrypted(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
|
||||
func (u *UserUpsertOne) UpdateTotpSecretEncrypted() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateTotpSecretEncrypted()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||
func (u *UserUpsertOne) ClearTotpSecretEncrypted() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.ClearTotpSecretEncrypted()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (u *UserUpsertOne) SetTotpEnabled(v bool) *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetTotpEnabled(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
|
||||
func (u *UserUpsertOne) UpdateTotpEnabled() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateTotpEnabled()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (u *UserUpsertOne) SetTotpEnabledAt(v time.Time) *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetTotpEnabledAt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
|
||||
func (u *UserUpsertOne) UpdateTotpEnabledAt() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateTotpEnabledAt()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||
func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.ClearTotpEnabledAt()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -1393,6 +1558,62 @@ func (u *UserUpsertBulk) UpdateNotes() *UserUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (u *UserUpsertBulk) SetTotpSecretEncrypted(v string) *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetTotpSecretEncrypted(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
|
||||
func (u *UserUpsertBulk) UpdateTotpSecretEncrypted() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateTotpSecretEncrypted()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||
func (u *UserUpsertBulk) ClearTotpSecretEncrypted() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.ClearTotpSecretEncrypted()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (u *UserUpsertBulk) SetTotpEnabled(v bool) *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetTotpEnabled(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
|
||||
func (u *UserUpsertBulk) UpdateTotpEnabled() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateTotpEnabled()
|
||||
})
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (u *UserUpsertBulk) SetTotpEnabledAt(v time.Time) *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.SetTotpEnabledAt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
|
||||
func (u *UserUpsertBulk) UpdateTotpEnabledAt() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.UpdateTotpEnabledAt()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||
func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
|
||||
return u.Update(func(s *UserUpsert) {
|
||||
s.ClearTotpEnabledAt()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -187,6 +187,60 @@ func (_u *UserUpdate) SetNillableNotes(v *string) *UserUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (_u *UserUpdate) SetTotpSecretEncrypted(v string) *UserUpdate {
|
||||
_u.mutation.SetTotpSecretEncrypted(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
|
||||
func (_u *UserUpdate) SetNillableTotpSecretEncrypted(v *string) *UserUpdate {
|
||||
if v != nil {
|
||||
_u.SetTotpSecretEncrypted(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||
func (_u *UserUpdate) ClearTotpSecretEncrypted() *UserUpdate {
|
||||
_u.mutation.ClearTotpSecretEncrypted()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (_u *UserUpdate) SetTotpEnabled(v bool) *UserUpdate {
|
||||
_u.mutation.SetTotpEnabled(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
|
||||
func (_u *UserUpdate) SetNillableTotpEnabled(v *bool) *UserUpdate {
|
||||
if v != nil {
|
||||
_u.SetTotpEnabled(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (_u *UserUpdate) SetTotpEnabledAt(v time.Time) *UserUpdate {
|
||||
_u.mutation.SetTotpEnabledAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
|
||||
func (_u *UserUpdate) SetNillableTotpEnabledAt(v *time.Time) *UserUpdate {
|
||||
if v != nil {
|
||||
_u.SetTotpEnabledAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||
func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
|
||||
_u.mutation.ClearTotpEnabledAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -603,6 +657,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.Notes(); ok {
|
||||
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
|
||||
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.TotpSecretEncryptedCleared() {
|
||||
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.TotpEnabled(); ok {
|
||||
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.TotpEnabledAt(); ok {
|
||||
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.TotpEnabledAtCleared() {
|
||||
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1147,6 +1216,60 @@ func (_u *UserUpdateOne) SetNillableNotes(v *string) *UserUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||
func (_u *UserUpdateOne) SetTotpSecretEncrypted(v string) *UserUpdateOne {
|
||||
_u.mutation.SetTotpSecretEncrypted(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
|
||||
func (_u *UserUpdateOne) SetNillableTotpSecretEncrypted(v *string) *UserUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetTotpSecretEncrypted(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||
func (_u *UserUpdateOne) ClearTotpSecretEncrypted() *UserUpdateOne {
|
||||
_u.mutation.ClearTotpSecretEncrypted()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTotpEnabled sets the "totp_enabled" field.
|
||||
func (_u *UserUpdateOne) SetTotpEnabled(v bool) *UserUpdateOne {
|
||||
_u.mutation.SetTotpEnabled(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
|
||||
func (_u *UserUpdateOne) SetNillableTotpEnabled(v *bool) *UserUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetTotpEnabled(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||
func (_u *UserUpdateOne) SetTotpEnabledAt(v time.Time) *UserUpdateOne {
|
||||
_u.mutation.SetTotpEnabledAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
|
||||
func (_u *UserUpdateOne) SetNillableTotpEnabledAt(v *time.Time) *UserUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetTotpEnabledAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||
func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
|
||||
_u.mutation.ClearTotpEnabledAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -1593,6 +1716,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
||||
if value, ok := _u.mutation.Notes(); ok {
|
||||
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
|
||||
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.TotpSecretEncryptedCleared() {
|
||||
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.TotpEnabled(); ok {
|
||||
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.TotpEnabledAt(); ok {
|
||||
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.TotpEnabledAtCleared() {
|
||||
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -31,11 +31,13 @@ require (
|
||||
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect
|
||||
dario.cat/mergo v1.0.2 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/agext/levenshtein v1.2.3 // indirect
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
|
||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
@@ -97,6 +99,7 @@ require (
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
@@ -104,9 +107,11 @@ require (
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/pquerna/otp v1.5.0 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.57.1 // indirect
|
||||
github.com/refraction-networking/utls v1.8.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/robfig/cron/v3 v3.0.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
@@ -139,7 +144,7 @@ require (
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/mod v0.30.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
@@ -148,4 +153,8 @@ require (
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.44.1 // indirect
|
||||
)
|
||||
|
||||
@@ -20,6 +20,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
|
||||
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
@@ -141,6 +143,7 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
||||
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
@@ -199,6 +202,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
@@ -214,6 +219,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
|
||||
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
|
||||
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
||||
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
||||
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||
@@ -224,6 +231,8 @@ github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4Vi
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
|
||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
@@ -338,6 +347,8 @@ golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
@@ -365,6 +376,7 @@ golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY=
|
||||
golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
|
||||
golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM=
|
||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
|
||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@@ -387,4 +399,12 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
|
||||
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas=
|
||||
modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
||||
@@ -19,7 +19,9 @@ const (
|
||||
RunModeSimple = "simple"
|
||||
)
|
||||
|
||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' https://challenges.cloudflare.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
|
||||
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
|
||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
|
||||
// 连接池隔离策略常量
|
||||
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
|
||||
@@ -45,6 +47,7 @@ type Config struct {
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Ops OpsConfig `mapstructure:"ops"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Totp TotpConfig `mapstructure:"totp"`
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
@@ -53,6 +56,7 @@ type Config struct {
|
||||
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
|
||||
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
|
||||
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
|
||||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
@@ -232,6 +236,10 @@ type GatewayConfig struct {
|
||||
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
|
||||
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
|
||||
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
|
||||
// SessionIdleTimeoutMinutes: 会话空闲超时时间(分钟),默认 5 分钟
|
||||
// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制功能
|
||||
// 空闲超过此时间的会话将被自动释放
|
||||
SessionIdleTimeoutMinutes int `mapstructure:"session_idle_timeout_minutes"`
|
||||
|
||||
// StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用
|
||||
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
|
||||
@@ -251,8 +259,43 @@ type GatewayConfig struct {
|
||||
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
||||
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
||||
|
||||
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
|
||||
MaxAccountSwitches int `mapstructure:"max_account_switches"`
|
||||
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
|
||||
MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"`
|
||||
|
||||
// Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用
|
||||
AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"`
|
||||
|
||||
// Scheduling: 账号调度相关配置
|
||||
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
|
||||
|
||||
// TLSFingerprint: TLS指纹伪装配置
|
||||
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
|
||||
}
|
||||
|
||||
// TLSFingerprintConfig TLS指纹伪装配置
|
||||
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
|
||||
type TLSFingerprintConfig struct {
|
||||
// Enabled: 是否全局启用TLS指纹功能
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// Profiles: 预定义的TLS指纹配置模板
|
||||
// key 为模板名称,如 "claude_cli_v2", "chrome_120" 等
|
||||
Profiles map[string]TLSProfileConfig `mapstructure:"profiles"`
|
||||
}
|
||||
|
||||
// TLSProfileConfig 单个TLS指纹模板的配置
|
||||
type TLSProfileConfig struct {
|
||||
// Name: 模板显示名称
|
||||
Name string `mapstructure:"name"`
|
||||
// EnableGREASE: 是否启用GREASE扩展(Chrome使用,Node.js不使用)
|
||||
EnableGREASE bool `mapstructure:"enable_grease"`
|
||||
// CipherSuites: TLS加密套件列表(空则使用内置默认值)
|
||||
CipherSuites []uint16 `mapstructure:"cipher_suites"`
|
||||
// Curves: 椭圆曲线列表(空则使用内置默认值)
|
||||
Curves []uint16 `mapstructure:"curves"`
|
||||
// PointFormats: 点格式列表(空则使用内置默认值)
|
||||
PointFormats []uint8 `mapstructure:"point_formats"`
|
||||
}
|
||||
|
||||
// GatewaySchedulingConfig accounts scheduling configuration.
|
||||
@@ -265,6 +308,9 @@ type GatewaySchedulingConfig struct {
|
||||
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
|
||||
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
|
||||
|
||||
// 兜底层账户选择策略: "last_used"(按最后使用时间排序,默认) 或 "random"(随机)
|
||||
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
|
||||
|
||||
// 负载计算
|
||||
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
|
||||
|
||||
@@ -421,6 +467,16 @@ type JWTConfig struct {
|
||||
ExpireHour int `mapstructure:"expire_hour"`
|
||||
}
|
||||
|
||||
// TotpConfig TOTP 双因素认证配置
|
||||
type TotpConfig struct {
|
||||
// EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥(32 字节 hex 编码)
|
||||
// 如果为空,将自动生成一个随机密钥(仅适用于开发环境)
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
// EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成)
|
||||
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
|
||||
EncryptionKeyConfigured bool `mapstructure:"-"`
|
||||
}
|
||||
|
||||
type TurnstileConfig struct {
|
||||
Required bool `mapstructure:"required"`
|
||||
}
|
||||
@@ -487,6 +543,20 @@ type DashboardAggregationRetentionConfig struct {
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
}
|
||||
|
||||
// UsageCleanupConfig 使用记录清理任务配置
|
||||
type UsageCleanupConfig struct {
|
||||
// Enabled: 是否启用清理任务执行器
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// MaxRangeDays: 单次任务允许的最大时间跨度(天)
|
||||
MaxRangeDays int `mapstructure:"max_range_days"`
|
||||
// BatchSize: 单批删除数量
|
||||
BatchSize int `mapstructure:"batch_size"`
|
||||
// WorkerIntervalSeconds: 后台任务轮询间隔(秒)
|
||||
WorkerIntervalSeconds int `mapstructure:"worker_interval_seconds"`
|
||||
// TaskTimeoutSeconds: 单次任务最大执行时长(秒)
|
||||
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
|
||||
}
|
||||
|
||||
func NormalizeRunMode(value string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(value))
|
||||
switch normalized {
|
||||
@@ -567,6 +637,20 @@ func Load() (*Config, error) {
|
||||
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
||||
}
|
||||
|
||||
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
|
||||
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
||||
if cfg.Totp.EncryptionKey == "" {
|
||||
key, err := generateJWTSecret(32) // Reuse the same random generation function
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate totp encryption key error: %w", err)
|
||||
}
|
||||
cfg.Totp.EncryptionKey = key
|
||||
cfg.Totp.EncryptionKeyConfigured = false
|
||||
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
|
||||
} else {
|
||||
cfg.Totp.EncryptionKeyConfigured = true
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("validate config error: %w", err)
|
||||
}
|
||||
@@ -697,6 +781,9 @@ func setDefaults() {
|
||||
viper.SetDefault("jwt.secret", "")
|
||||
viper.SetDefault("jwt.expire_hour", 24)
|
||||
|
||||
// TOTP
|
||||
viper.SetDefault("totp.encryption_key", "")
|
||||
|
||||
// Default
|
||||
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
|
||||
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.
|
||||
@@ -747,12 +834,22 @@ func setDefaults() {
|
||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||
|
||||
// Usage cleanup task
|
||||
viper.SetDefault("usage_cleanup.enabled", true)
|
||||
viper.SetDefault("usage_cleanup.max_range_days", 31)
|
||||
viper.SetDefault("usage_cleanup.batch_size", 5000)
|
||||
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
|
||||
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
|
||||
|
||||
// Gateway
|
||||
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||
viper.SetDefault("gateway.log_upstream_error_body", true)
|
||||
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
|
||||
viper.SetDefault("gateway.inject_beta_for_apikey", false)
|
||||
viper.SetDefault("gateway.failover_on_400", false)
|
||||
viper.SetDefault("gateway.max_account_switches", 10)
|
||||
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
||||
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
||||
@@ -765,11 +862,12 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
|
||||
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
||||
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
|
||||
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
||||
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
|
||||
@@ -781,6 +879,8 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
|
||||
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
|
||||
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
|
||||
// TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
|
||||
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
||||
viper.SetDefault("concurrency.ping_interval", 10)
|
||||
|
||||
// TokenRefresh
|
||||
@@ -983,6 +1083,33 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
|
||||
}
|
||||
}
|
||||
if c.UsageCleanup.Enabled {
|
||||
if c.UsageCleanup.MaxRangeDays <= 0 {
|
||||
return fmt.Errorf("usage_cleanup.max_range_days must be positive")
|
||||
}
|
||||
if c.UsageCleanup.BatchSize <= 0 {
|
||||
return fmt.Errorf("usage_cleanup.batch_size must be positive")
|
||||
}
|
||||
if c.UsageCleanup.WorkerIntervalSeconds <= 0 {
|
||||
return fmt.Errorf("usage_cleanup.worker_interval_seconds must be positive")
|
||||
}
|
||||
if c.UsageCleanup.TaskTimeoutSeconds <= 0 {
|
||||
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be positive")
|
||||
}
|
||||
} else {
|
||||
if c.UsageCleanup.MaxRangeDays < 0 {
|
||||
return fmt.Errorf("usage_cleanup.max_range_days must be non-negative")
|
||||
}
|
||||
if c.UsageCleanup.BatchSize < 0 {
|
||||
return fmt.Errorf("usage_cleanup.batch_size must be non-negative")
|
||||
}
|
||||
if c.UsageCleanup.WorkerIntervalSeconds < 0 {
|
||||
return fmt.Errorf("usage_cleanup.worker_interval_seconds must be non-negative")
|
||||
}
|
||||
if c.UsageCleanup.TaskTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
|
||||
}
|
||||
}
|
||||
if c.Gateway.MaxBodySize <= 0 {
|
||||
return fmt.Errorf("gateway.max_body_size must be positive")
|
||||
}
|
||||
|
||||
@@ -280,3 +280,573 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
|
||||
t.Fatalf("Validate() expected backfill_max_days error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if !cfg.UsageCleanup.Enabled {
|
||||
t.Fatalf("UsageCleanup.Enabled = false, want true")
|
||||
}
|
||||
if cfg.UsageCleanup.MaxRangeDays != 31 {
|
||||
t.Fatalf("UsageCleanup.MaxRangeDays = %d, want 31", cfg.UsageCleanup.MaxRangeDays)
|
||||
}
|
||||
if cfg.UsageCleanup.BatchSize != 5000 {
|
||||
t.Fatalf("UsageCleanup.BatchSize = %d, want 5000", cfg.UsageCleanup.BatchSize)
|
||||
}
|
||||
if cfg.UsageCleanup.WorkerIntervalSeconds != 10 {
|
||||
t.Fatalf("UsageCleanup.WorkerIntervalSeconds = %d, want 10", cfg.UsageCleanup.WorkerIntervalSeconds)
|
||||
}
|
||||
if cfg.UsageCleanup.TaskTimeoutSeconds != 1800 {
|
||||
t.Fatalf("UsageCleanup.TaskTimeoutSeconds = %d, want 1800", cfg.UsageCleanup.TaskTimeoutSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.UsageCleanup.Enabled = true
|
||||
cfg.UsageCleanup.MaxRangeDays = 0
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error for usage_cleanup.max_range_days, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "usage_cleanup.max_range_days") {
|
||||
t.Fatalf("Validate() expected max_range_days error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.UsageCleanup.Enabled = false
|
||||
cfg.UsageCleanup.BatchSize = -1
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error for usage_cleanup.batch_size, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "usage_cleanup.batch_size") {
|
||||
t.Fatalf("Validate() expected batch_size error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigAddressHelpers(t *testing.T) {
|
||||
server := ServerConfig{Host: "127.0.0.1", Port: 9000}
|
||||
if server.Address() != "127.0.0.1:9000" {
|
||||
t.Fatalf("ServerConfig.Address() = %q", server.Address())
|
||||
}
|
||||
|
||||
dbCfg := DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "",
|
||||
DBName: "sub2api",
|
||||
SSLMode: "disable",
|
||||
}
|
||||
if !strings.Contains(dbCfg.DSN(), "password=") {
|
||||
} else {
|
||||
t.Fatalf("DatabaseConfig.DSN() should not include password when empty")
|
||||
}
|
||||
|
||||
dbCfg.Password = "secret"
|
||||
if !strings.Contains(dbCfg.DSN(), "password=secret") {
|
||||
t.Fatalf("DatabaseConfig.DSN() missing password")
|
||||
}
|
||||
|
||||
dbCfg.Password = ""
|
||||
if strings.Contains(dbCfg.DSNWithTimezone("UTC"), "password=") {
|
||||
t.Fatalf("DatabaseConfig.DSNWithTimezone() should omit password when empty")
|
||||
}
|
||||
|
||||
if !strings.Contains(dbCfg.DSNWithTimezone(""), "TimeZone=Asia/Shanghai") {
|
||||
t.Fatalf("DatabaseConfig.DSNWithTimezone() should use default timezone")
|
||||
}
|
||||
if !strings.Contains(dbCfg.DSNWithTimezone("UTC"), "TimeZone=UTC") {
|
||||
t.Fatalf("DatabaseConfig.DSNWithTimezone() should use provided timezone")
|
||||
}
|
||||
|
||||
redis := RedisConfig{Host: "redis", Port: 6379}
|
||||
if redis.Address() != "redis:6379" {
|
||||
t.Fatalf("RedisConfig.Address() = %q", redis.Address())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStringSlice(t *testing.T) {
|
||||
values := normalizeStringSlice([]string{" a ", "", "b", " ", "c"})
|
||||
if len(values) != 3 || values[0] != "a" || values[1] != "b" || values[2] != "c" {
|
||||
t.Fatalf("normalizeStringSlice() unexpected result: %#v", values)
|
||||
}
|
||||
if normalizeStringSlice(nil) != nil {
|
||||
t.Fatalf("normalizeStringSlice(nil) expected nil slice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetServerAddressFromEnv(t *testing.T) {
|
||||
t.Setenv("SERVER_HOST", "127.0.0.1")
|
||||
t.Setenv("SERVER_PORT", "9090")
|
||||
|
||||
address := GetServerAddress()
|
||||
if address != "127.0.0.1:9090" {
|
||||
t.Fatalf("GetServerAddress() = %q", address)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAbsoluteHTTPURL(t *testing.T) {
|
||||
if err := ValidateAbsoluteHTTPURL("https://example.com/path"); err != nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL valid url error: %v", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(""); err == nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL should reject empty url")
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL("/relative"); err == nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL should reject relative url")
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL("ftp://example.com"); err == nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL should reject ftp scheme")
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL("https://example.com/#frag"); err == nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL should reject fragment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFrontendRedirectURL(t *testing.T) {
|
||||
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL("https://example.com/auth"); err != nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL absolute error: %v", err)
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL("example.com/path"); err == nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL should reject non-absolute url")
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL("//evil.com"); err == nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL should reject // prefix")
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL("javascript:alert(1)"); err == nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL should reject javascript scheme")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWarnIfInsecureURL(t *testing.T) {
|
||||
warnIfInsecureURL("test", "http://example.com")
|
||||
warnIfInsecureURL("test", "bad://url")
|
||||
}
|
||||
|
||||
func TestGenerateJWTSecretDefaultLength(t *testing.T) {
|
||||
secret, err := generateJWTSecret(0)
|
||||
if err != nil {
|
||||
t.Fatalf("generateJWTSecret error: %v", err)
|
||||
}
|
||||
if len(secret) == 0 {
|
||||
t.Fatalf("generateJWTSecret returned empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
cfg.Ops.Cleanup.Enabled = true
|
||||
cfg.Ops.Cleanup.Schedule = ""
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error for ops.cleanup.schedule")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "ops.cleanup.schedule") {
|
||||
t.Fatalf("Validate() expected ops.cleanup.schedule error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConcurrencyPingInterval(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
cfg.Concurrency.PingInterval = 3
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() expected error for concurrency.ping_interval")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "concurrency.ping_interval") {
|
||||
t.Fatalf("Validate() expected concurrency.ping_interval error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvideConfig(t *testing.T) {
|
||||
viper.Reset()
|
||||
if _, err := ProvideConfig(); err != nil {
|
||||
t.Fatalf("ProvideConfig() error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Security.CSP.Enabled = true
|
||||
cfg.Security.CSP.Policy = "default-src 'self'"
|
||||
|
||||
cfg.LinuxDo.Enabled = true
|
||||
cfg.LinuxDo.ClientID = "client"
|
||||
cfg.LinuxDo.ClientSecret = "secret"
|
||||
cfg.LinuxDo.AuthorizeURL = "https://example.com/oauth2/authorize"
|
||||
cfg.LinuxDo.TokenURL = "https://example.com/oauth2/token"
|
||||
cfg.LinuxDo.UserInfoURL = "https://example.com/oauth2/userinfo"
|
||||
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
|
||||
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
|
||||
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate() unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTSecretStrength(t *testing.T) {
|
||||
if !isWeakJWTSecret("change-me-in-production") {
|
||||
t.Fatalf("isWeakJWTSecret should detect weak secret")
|
||||
}
|
||||
if isWeakJWTSecret("StrongSecretValue") {
|
||||
t.Fatalf("isWeakJWTSecret should accept strong secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateJWTSecretWithLength(t *testing.T) {
|
||||
secret, err := generateJWTSecret(16)
|
||||
if err != nil {
|
||||
t.Fatalf("generateJWTSecret error: %v", err)
|
||||
}
|
||||
if len(secret) == 0 {
|
||||
t.Fatalf("generateJWTSecret returned empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
|
||||
if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFrontendRedirectURLInvalidChars(t *testing.T) {
|
||||
if err := ValidateFrontendRedirectURL("/auth/\ncallback"); err == nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL should reject invalid chars")
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL("http://"); err == nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL should reject missing host")
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL("mailto:user@example.com"); err == nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL should reject mailto")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWarnIfInsecureURLHTTPS(t *testing.T) {
|
||||
warnIfInsecureURL("secure", "https://example.com")
|
||||
}
|
||||
|
||||
func TestValidateConfigErrors(t *testing.T) {
|
||||
buildValid := func(t *testing.T) *Config {
|
||||
t.Helper()
|
||||
viper.Reset()
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
mutate func(*Config)
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "jwt expire hour positive",
|
||||
mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
|
||||
wantErr: "jwt.expire_hour must be positive",
|
||||
},
|
||||
{
|
||||
name: "jwt expire hour max",
|
||||
mutate: func(c *Config) { c.JWT.ExpireHour = 200 },
|
||||
wantErr: "jwt.expire_hour must be <= 168",
|
||||
},
|
||||
{
|
||||
name: "csp policy required",
|
||||
mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" },
|
||||
wantErr: "security.csp.policy",
|
||||
},
|
||||
{
|
||||
name: "linuxdo client id required",
|
||||
mutate: func(c *Config) {
|
||||
c.LinuxDo.Enabled = true
|
||||
c.LinuxDo.ClientID = ""
|
||||
},
|
||||
wantErr: "linuxdo_connect.client_id",
|
||||
},
|
||||
{
|
||||
name: "linuxdo token auth method",
|
||||
mutate: func(c *Config) {
|
||||
c.LinuxDo.Enabled = true
|
||||
c.LinuxDo.ClientID = "client"
|
||||
c.LinuxDo.ClientSecret = "secret"
|
||||
c.LinuxDo.AuthorizeURL = "https://example.com/authorize"
|
||||
c.LinuxDo.TokenURL = "https://example.com/token"
|
||||
c.LinuxDo.UserInfoURL = "https://example.com/userinfo"
|
||||
c.LinuxDo.RedirectURL = "https://example.com/callback"
|
||||
c.LinuxDo.FrontendRedirectURL = "/auth/callback"
|
||||
c.LinuxDo.TokenAuthMethod = "invalid"
|
||||
},
|
||||
wantErr: "linuxdo_connect.token_auth_method",
|
||||
},
|
||||
{
|
||||
name: "billing circuit breaker threshold",
|
||||
mutate: func(c *Config) { c.Billing.CircuitBreaker.FailureThreshold = 0 },
|
||||
wantErr: "billing.circuit_breaker.failure_threshold",
|
||||
},
|
||||
{
|
||||
name: "billing circuit breaker reset",
|
||||
mutate: func(c *Config) { c.Billing.CircuitBreaker.ResetTimeoutSeconds = 0 },
|
||||
wantErr: "billing.circuit_breaker.reset_timeout_seconds",
|
||||
},
|
||||
{
|
||||
name: "billing circuit breaker half open",
|
||||
mutate: func(c *Config) { c.Billing.CircuitBreaker.HalfOpenRequests = 0 },
|
||||
wantErr: "billing.circuit_breaker.half_open_requests",
|
||||
},
|
||||
{
|
||||
name: "database max open conns",
|
||||
mutate: func(c *Config) { c.Database.MaxOpenConns = 0 },
|
||||
wantErr: "database.max_open_conns",
|
||||
},
|
||||
{
|
||||
name: "database max lifetime",
|
||||
mutate: func(c *Config) { c.Database.ConnMaxLifetimeMinutes = -1 },
|
||||
wantErr: "database.conn_max_lifetime_minutes",
|
||||
},
|
||||
{
|
||||
name: "database idle exceeds open",
|
||||
mutate: func(c *Config) { c.Database.MaxIdleConns = c.Database.MaxOpenConns + 1 },
|
||||
wantErr: "database.max_idle_conns cannot exceed",
|
||||
},
|
||||
{
|
||||
name: "redis dial timeout",
|
||||
mutate: func(c *Config) { c.Redis.DialTimeoutSeconds = 0 },
|
||||
wantErr: "redis.dial_timeout_seconds",
|
||||
},
|
||||
{
|
||||
name: "redis read timeout",
|
||||
mutate: func(c *Config) { c.Redis.ReadTimeoutSeconds = 0 },
|
||||
wantErr: "redis.read_timeout_seconds",
|
||||
},
|
||||
{
|
||||
name: "redis write timeout",
|
||||
mutate: func(c *Config) { c.Redis.WriteTimeoutSeconds = 0 },
|
||||
wantErr: "redis.write_timeout_seconds",
|
||||
},
|
||||
{
|
||||
name: "redis pool size",
|
||||
mutate: func(c *Config) { c.Redis.PoolSize = 0 },
|
||||
wantErr: "redis.pool_size",
|
||||
},
|
||||
{
|
||||
name: "redis idle exceeds pool",
|
||||
mutate: func(c *Config) { c.Redis.MinIdleConns = c.Redis.PoolSize + 1 },
|
||||
wantErr: "redis.min_idle_conns cannot exceed",
|
||||
},
|
||||
{
|
||||
name: "dashboard cache disabled negative",
|
||||
mutate: func(c *Config) { c.Dashboard.Enabled = false; c.Dashboard.StatsTTLSeconds = -1 },
|
||||
wantErr: "dashboard_cache.stats_ttl_seconds",
|
||||
},
|
||||
{
|
||||
name: "dashboard cache fresh ttl positive",
|
||||
mutate: func(c *Config) { c.Dashboard.Enabled = true; c.Dashboard.StatsFreshTTLSeconds = 0 },
|
||||
wantErr: "dashboard_cache.stats_fresh_ttl_seconds",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation enabled interval",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.IntervalSeconds = 0 },
|
||||
wantErr: "dashboard_aggregation.interval_seconds",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation backfill positive",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.BackfillEnabled = true
|
||||
c.DashboardAgg.BackfillMaxDays = 0
|
||||
},
|
||||
wantErr: "dashboard_aggregation.backfill_max_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation retention",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation disabled interval",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||
wantErr: "dashboard_aggregation.interval_seconds",
|
||||
},
|
||||
{
|
||||
name: "usage cleanup max range",
|
||||
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.MaxRangeDays = 0 },
|
||||
wantErr: "usage_cleanup.max_range_days",
|
||||
},
|
||||
{
|
||||
name: "usage cleanup worker interval",
|
||||
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.WorkerIntervalSeconds = 0 },
|
||||
wantErr: "usage_cleanup.worker_interval_seconds",
|
||||
},
|
||||
{
|
||||
name: "usage cleanup batch size",
|
||||
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.BatchSize = 0 },
|
||||
wantErr: "usage_cleanup.batch_size",
|
||||
},
|
||||
{
|
||||
name: "usage cleanup disabled negative",
|
||||
mutate: func(c *Config) { c.UsageCleanup.Enabled = false; c.UsageCleanup.BatchSize = -1 },
|
||||
wantErr: "usage_cleanup.batch_size",
|
||||
},
|
||||
{
|
||||
name: "gateway max body size",
|
||||
mutate: func(c *Config) { c.Gateway.MaxBodySize = 0 },
|
||||
wantErr: "gateway.max_body_size",
|
||||
},
|
||||
{
|
||||
name: "gateway max idle conns",
|
||||
mutate: func(c *Config) { c.Gateway.MaxIdleConns = 0 },
|
||||
wantErr: "gateway.max_idle_conns",
|
||||
},
|
||||
{
|
||||
name: "gateway max idle conns per host",
|
||||
mutate: func(c *Config) { c.Gateway.MaxIdleConnsPerHost = 0 },
|
||||
wantErr: "gateway.max_idle_conns_per_host",
|
||||
},
|
||||
{
|
||||
name: "gateway idle timeout",
|
||||
mutate: func(c *Config) { c.Gateway.IdleConnTimeoutSeconds = 0 },
|
||||
wantErr: "gateway.idle_conn_timeout_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway max upstream clients",
|
||||
mutate: func(c *Config) { c.Gateway.MaxUpstreamClients = 0 },
|
||||
wantErr: "gateway.max_upstream_clients",
|
||||
},
|
||||
{
|
||||
name: "gateway client idle ttl",
|
||||
mutate: func(c *Config) { c.Gateway.ClientIdleTTLSeconds = 0 },
|
||||
wantErr: "gateway.client_idle_ttl_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway concurrency slot ttl",
|
||||
mutate: func(c *Config) { c.Gateway.ConcurrencySlotTTLMinutes = 0 },
|
||||
wantErr: "gateway.concurrency_slot_ttl_minutes",
|
||||
},
|
||||
{
|
||||
name: "gateway max conns per host",
|
||||
mutate: func(c *Config) { c.Gateway.MaxConnsPerHost = -1 },
|
||||
wantErr: "gateway.max_conns_per_host",
|
||||
},
|
||||
{
|
||||
name: "gateway connection isolation",
|
||||
mutate: func(c *Config) { c.Gateway.ConnectionPoolIsolation = "invalid" },
|
||||
wantErr: "gateway.connection_pool_isolation",
|
||||
},
|
||||
{
|
||||
name: "gateway stream keepalive range",
|
||||
mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
|
||||
wantErr: "gateway.stream_keepalive_interval",
|
||||
},
|
||||
{
|
||||
name: "gateway stream data interval range",
|
||||
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
|
||||
wantErr: "gateway.stream_data_interval_timeout",
|
||||
},
|
||||
{
|
||||
name: "gateway stream data interval negative",
|
||||
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 },
|
||||
wantErr: "gateway.stream_data_interval_timeout must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway max line size",
|
||||
mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 },
|
||||
wantErr: "gateway.max_line_size must be at least",
|
||||
},
|
||||
{
|
||||
name: "gateway max line size negative",
|
||||
mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
|
||||
wantErr: "gateway.max_line_size must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway scheduling sticky waiting",
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
|
||||
wantErr: "gateway.scheduling.sticky_session_max_waiting",
|
||||
},
|
||||
{
|
||||
name: "gateway scheduling outbox poll",
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },
|
||||
wantErr: "gateway.scheduling.outbox_poll_interval_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway scheduling outbox failures",
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxLagRebuildFailures = 0 },
|
||||
wantErr: "gateway.scheduling.outbox_lag_rebuild_failures",
|
||||
},
|
||||
{
|
||||
name: "gateway outbox lag rebuild",
|
||||
mutate: func(c *Config) {
|
||||
c.Gateway.Scheduling.OutboxLagWarnSeconds = 10
|
||||
c.Gateway.Scheduling.OutboxLagRebuildSeconds = 5
|
||||
},
|
||||
wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
|
||||
},
|
||||
{
|
||||
name: "ops metrics collector ttl",
|
||||
mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
|
||||
wantErr: "ops.metrics_collector_cache.ttl",
|
||||
},
|
||||
{
|
||||
name: "ops cleanup retention",
|
||||
mutate: func(c *Config) { c.Ops.Cleanup.ErrorLogRetentionDays = -1 },
|
||||
wantErr: "ops.cleanup.error_log_retention_days",
|
||||
},
|
||||
{
|
||||
name: "ops cleanup minute retention",
|
||||
mutate: func(c *Config) { c.Ops.Cleanup.MinuteMetricsRetentionDays = -1 },
|
||||
wantErr: "ops.cleanup.minute_metrics_retention_days",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := buildValid(t)
|
||||
tt.mutate(cfg)
|
||||
err := cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,6 +44,8 @@ type AccountHandler struct {
|
||||
accountTestService *service.AccountTestService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
crsSyncService *service.CRSSyncService
|
||||
sessionLimitCache service.SessionLimitCache
|
||||
tokenCacheInvalidator service.TokenCacheInvalidator
|
||||
}
|
||||
|
||||
// NewAccountHandler creates a new admin account handler
|
||||
@@ -58,6 +60,8 @@ func NewAccountHandler(
|
||||
accountTestService *service.AccountTestService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
crsSyncService *service.CRSSyncService,
|
||||
sessionLimitCache service.SessionLimitCache,
|
||||
tokenCacheInvalidator service.TokenCacheInvalidator,
|
||||
) *AccountHandler {
|
||||
return &AccountHandler{
|
||||
adminService: adminService,
|
||||
@@ -70,6 +74,8 @@ func NewAccountHandler(
|
||||
accountTestService: accountTestService,
|
||||
concurrencyService: concurrencyService,
|
||||
crsSyncService: crsSyncService,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
tokenCacheInvalidator: tokenCacheInvalidator,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,6 +90,7 @@ type CreateAccountRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
@@ -101,6 +108,7 @@ type UpdateAccountRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
@@ -115,6 +123,7 @@ type BulkUpdateAccountsRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
|
||||
Schedulable *bool `json:"schedulable"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
@@ -127,6 +136,9 @@ type BulkUpdateAccountsRequest struct {
|
||||
type AccountWithConcurrency struct {
|
||||
*dto.Account
|
||||
CurrentConcurrency int `json:"current_concurrency"`
|
||||
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
|
||||
CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用
|
||||
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
|
||||
}
|
||||
|
||||
// List handles listing all accounts with pagination
|
||||
@@ -161,13 +173,87 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
concurrencyCounts = make(map[int64]int)
|
||||
}
|
||||
|
||||
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
if acc.GetWindowCostLimit() > 0 {
|
||||
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
|
||||
}
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 并行获取窗口费用和活跃会话数
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
|
||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取窗口费用(并行查询)
|
||||
if len(windowCostAccountIDs) > 0 {
|
||||
windowCosts = make(map[int64]float64)
|
||||
var mu sync.Mutex
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
g.SetLimit(10) // 限制并发数
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
|
||||
continue
|
||||
}
|
||||
accCopy := acc // 闭包捕获
|
||||
g.Go(func() error {
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := accCopy.GetCurrentWindowStartTime()
|
||||
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
|
||||
if err == nil && stats != nil {
|
||||
mu.Lock()
|
||||
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
|
||||
mu.Unlock()
|
||||
}
|
||||
return nil // 不返回错误,允许部分失败
|
||||
})
|
||||
}
|
||||
_ = g.Wait()
|
||||
}
|
||||
|
||||
// Build response with concurrency info
|
||||
result := make([]AccountWithConcurrency, len(accounts))
|
||||
for i := range accounts {
|
||||
result[i] = AccountWithConcurrency{
|
||||
Account: dto.AccountFromService(&accounts[i]),
|
||||
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
|
||||
acc := &accounts[i]
|
||||
item := AccountWithConcurrency{
|
||||
Account: dto.AccountFromService(acc),
|
||||
CurrentConcurrency: concurrencyCounts[acc.ID],
|
||||
}
|
||||
|
||||
// 添加窗口费用(仅当启用时)
|
||||
if windowCosts != nil {
|
||||
if cost, ok := windowCosts[acc.ID]; ok {
|
||||
item.CurrentWindowCost = &cost
|
||||
}
|
||||
}
|
||||
|
||||
// 添加活跃会话数(仅当启用时)
|
||||
if activeSessions != nil {
|
||||
if count, ok := activeSessions[acc.ID]; ok {
|
||||
item.ActiveSessions = &count
|
||||
}
|
||||
}
|
||||
|
||||
result[i] = item
|
||||
}
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
@@ -199,6 +285,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
|
||||
response.BadRequest(c, "rate_multiplier must be >= 0")
|
||||
return
|
||||
}
|
||||
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
@@ -213,6 +303,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
@@ -258,6 +349,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
|
||||
response.BadRequest(c, "rate_multiplier must be >= 0")
|
||||
return
|
||||
}
|
||||
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
@@ -271,6 +366,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency, // 指针类型,nil 表示未提供
|
||||
Priority: req.Priority, // 指针类型,nil 表示未提供
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
@@ -450,6 +546,41 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
|
||||
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
|
||||
newCredentials["project_id"] = oldProjectID
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证(token 本身刷新成功了)
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
}
|
||||
// 不标记为 error,只返回警告信息
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
||||
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
|
||||
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
@@ -485,6 +616,14 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
@@ -534,6 +673,15 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清除错误后,同时清除 token 缓存,确保下次请求会获取最新的 token(触发刷新或从 DB 读取)
|
||||
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
@@ -652,6 +800,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
|
||||
response.BadRequest(c, "rate_multiplier must be >= 0")
|
||||
return
|
||||
}
|
||||
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
@@ -660,6 +812,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
req.ProxyID != nil ||
|
||||
req.Concurrency != nil ||
|
||||
req.Priority != nil ||
|
||||
req.RateMultiplier != nil ||
|
||||
req.Status != "" ||
|
||||
req.Schedulable != nil ||
|
||||
req.GroupIDs != nil ||
|
||||
@@ -677,6 +830,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
Status: req.Status,
|
||||
Schedulable: req.Schedulable,
|
||||
GroupIDs: req.GroupIDs,
|
||||
|
||||
262
backend/internal/handler/admin/admin_basic_handlers_test.go
Normal file
262
backend/internal/handler/admin/admin_basic_handlers_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
adminSvc := newStubAdminService()
|
||||
|
||||
userHandler := NewUserHandler(adminSvc)
|
||||
groupHandler := NewGroupHandler(adminSvc)
|
||||
proxyHandler := NewProxyHandler(adminSvc)
|
||||
redeemHandler := NewRedeemHandler(adminSvc)
|
||||
|
||||
router.GET("/api/v1/admin/users", userHandler.List)
|
||||
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
|
||||
router.POST("/api/v1/admin/users", userHandler.Create)
|
||||
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
|
||||
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
|
||||
router.POST("/api/v1/admin/users/:id/balance", userHandler.UpdateBalance)
|
||||
router.GET("/api/v1/admin/users/:id/api-keys", userHandler.GetUserAPIKeys)
|
||||
router.GET("/api/v1/admin/users/:id/usage", userHandler.GetUserUsage)
|
||||
|
||||
router.GET("/api/v1/admin/groups", groupHandler.List)
|
||||
router.GET("/api/v1/admin/groups/all", groupHandler.GetAll)
|
||||
router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID)
|
||||
router.POST("/api/v1/admin/groups", groupHandler.Create)
|
||||
router.PUT("/api/v1/admin/groups/:id", groupHandler.Update)
|
||||
router.DELETE("/api/v1/admin/groups/:id", groupHandler.Delete)
|
||||
router.GET("/api/v1/admin/groups/:id/stats", groupHandler.GetStats)
|
||||
router.GET("/api/v1/admin/groups/:id/api-keys", groupHandler.GetGroupAPIKeys)
|
||||
|
||||
router.GET("/api/v1/admin/proxies", proxyHandler.List)
|
||||
router.GET("/api/v1/admin/proxies/all", proxyHandler.GetAll)
|
||||
router.GET("/api/v1/admin/proxies/:id", proxyHandler.GetByID)
|
||||
router.POST("/api/v1/admin/proxies", proxyHandler.Create)
|
||||
router.PUT("/api/v1/admin/proxies/:id", proxyHandler.Update)
|
||||
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
|
||||
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
|
||||
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
|
||||
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
|
||||
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
|
||||
|
||||
router.GET("/api/v1/admin/redeem-codes", redeemHandler.List)
|
||||
router.GET("/api/v1/admin/redeem-codes/:id", redeemHandler.GetByID)
|
||||
router.POST("/api/v1/admin/redeem-codes", redeemHandler.Generate)
|
||||
router.DELETE("/api/v1/admin/redeem-codes/:id", redeemHandler.Delete)
|
||||
router.POST("/api/v1/admin/redeem-codes/batch-delete", redeemHandler.BatchDelete)
|
||||
router.POST("/api/v1/admin/redeem-codes/:id/expire", redeemHandler.Expire)
|
||||
router.GET("/api/v1/admin/redeem-codes/:id/stats", redeemHandler.GetStats)
|
||||
|
||||
return router, adminSvc
|
||||
}
|
||||
|
||||
func TestUserHandlerEndpoints(t *testing.T) {
|
||||
router, _ := setupAdminRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/users?page=1&page_size=20", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
|
||||
body, _ := json.Marshal(createBody)
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
updateBody := map[string]any{"email": "updated@example.com"}
|
||||
body, _ = json.Marshal(updateBody)
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/users/1", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/users/1", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/balance", bytes.NewBufferString(`{"balance":1,"operation":"add"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/api-keys", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/usage?period=today", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestGroupHandlerEndpoints(t *testing.T) {
|
||||
router, _ := setupAdminRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/all", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ = json.Marshal(map[string]any{"name": "update"})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/groups/2", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/groups/2", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/api-keys", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestProxyHandlerEndpoints(t *testing.T) {
|
||||
router, _ := setupAdminRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/all", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{"name": "proxy", "protocol": "http", "host": "localhost", "port": 8080})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ = json.Marshal(map[string]any{"name": "proxy2"})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/proxies/4", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/proxies/4", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/test", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/accounts", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestRedeemHandlerEndpoints(t *testing.T) {
|
||||
router, _ := setupAdminRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{"count": 1, "type": "balance", "value": 10})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/redeem-codes/5", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/5/expire", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
134
backend/internal/handler/admin/admin_helpers_test.go
Normal file
134
backend/internal/handler/admin/admin_helpers_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseTimeRange(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req := httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-01&end_date=2024-01-02&timezone=UTC", nil)
|
||||
c.Request = req
|
||||
|
||||
start, end := parseTimeRange(c)
|
||||
require.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), start)
|
||||
require.Equal(t, time.Date(2024, 1, 3, 0, 0, 0, 0, time.UTC), end)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/?start_date=bad&timezone=UTC", nil)
|
||||
c.Request = req
|
||||
start, end = parseTimeRange(c)
|
||||
require.False(t, start.IsZero())
|
||||
require.False(t, end.IsZero())
|
||||
}
|
||||
|
||||
func TestParseOpsViewParam(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?view=excluded", nil)
|
||||
require.Equal(t, opsListViewExcluded, parseOpsViewParam(c))
|
||||
|
||||
c2, _ := gin.CreateTestContext(w)
|
||||
c2.Request = httptest.NewRequest(http.MethodGet, "/?view=all", nil)
|
||||
require.Equal(t, opsListViewAll, parseOpsViewParam(c2))
|
||||
|
||||
c3, _ := gin.CreateTestContext(w)
|
||||
c3.Request = httptest.NewRequest(http.MethodGet, "/?view=unknown", nil)
|
||||
require.Equal(t, opsListViewErrors, parseOpsViewParam(c3))
|
||||
|
||||
require.Equal(t, "", parseOpsViewParam(nil))
|
||||
}
|
||||
|
||||
func TestParseOpsDuration(t *testing.T) {
|
||||
dur, ok := parseOpsDuration("1h")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, time.Hour, dur)
|
||||
|
||||
_, ok = parseOpsDuration("invalid")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseOpsTimeRange(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
now := time.Now().UTC()
|
||||
startStr := now.Add(-time.Hour).Format(time.RFC3339)
|
||||
endStr := now.Format(time.RFC3339)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?start_time="+startStr+"&end_time="+endStr, nil)
|
||||
start, end, err := parseOpsTimeRange(c, "1h")
|
||||
require.NoError(t, err)
|
||||
require.True(t, start.Before(end))
|
||||
|
||||
c2, _ := gin.CreateTestContext(w)
|
||||
c2.Request = httptest.NewRequest(http.MethodGet, "/?start_time=bad", nil)
|
||||
_, _, err = parseOpsTimeRange(c2, "1h")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseOpsRealtimeWindow(t *testing.T) {
|
||||
dur, label, ok := parseOpsRealtimeWindow("5m")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 5*time.Minute, dur)
|
||||
require.Equal(t, "5min", label)
|
||||
|
||||
_, _, ok = parseOpsRealtimeWindow("invalid")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPickThroughputBucketSeconds(t *testing.T) {
|
||||
require.Equal(t, 60, pickThroughputBucketSeconds(30*time.Minute))
|
||||
require.Equal(t, 300, pickThroughputBucketSeconds(6*time.Hour))
|
||||
require.Equal(t, 3600, pickThroughputBucketSeconds(48*time.Hour))
|
||||
}
|
||||
|
||||
func TestParseOpsQueryMode(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?mode=raw", nil)
|
||||
require.Equal(t, service.ParseOpsQueryMode("raw"), parseOpsQueryMode(c))
|
||||
require.Equal(t, service.OpsQueryMode(""), parseOpsQueryMode(nil))
|
||||
}
|
||||
|
||||
func TestOpsAlertRuleValidation(t *testing.T) {
|
||||
raw := map[string]json.RawMessage{
|
||||
"name": json.RawMessage(`"High error rate"`),
|
||||
"metric_type": json.RawMessage(`"error_rate"`),
|
||||
"operator": json.RawMessage(`">"`),
|
||||
"threshold": json.RawMessage(`90`),
|
||||
}
|
||||
|
||||
validated, err := validateOpsAlertRulePayload(raw)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "High error rate", validated.Name)
|
||||
|
||||
_, err = validateOpsAlertRulePayload(map[string]json.RawMessage{})
|
||||
require.Error(t, err)
|
||||
|
||||
require.True(t, isPercentOrRateMetric("error_rate"))
|
||||
require.False(t, isPercentOrRateMetric("concurrency_queue_depth"))
|
||||
}
|
||||
|
||||
func TestOpsWSHelpers(t *testing.T) {
|
||||
prefixes, invalid := parseTrustedProxyList("10.0.0.0/8,invalid")
|
||||
require.Len(t, prefixes, 1)
|
||||
require.Len(t, invalid, 1)
|
||||
|
||||
host := hostWithoutPort("example.com:443")
|
||||
require.Equal(t, "example.com", host)
|
||||
|
||||
addr := netip.MustParseAddr("10.0.0.1")
|
||||
require.True(t, isAddrInTrustedProxies(addr, prefixes))
|
||||
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
|
||||
}
|
||||
294
backend/internal/handler/admin/admin_service_stub_test.go
Normal file
294
backend/internal/handler/admin/admin_service_stub_test.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type stubAdminService struct {
|
||||
users []service.User
|
||||
apiKeys []service.APIKey
|
||||
groups []service.Group
|
||||
accounts []service.Account
|
||||
proxies []service.Proxy
|
||||
proxyCounts []service.ProxyWithAccountCount
|
||||
redeems []service.RedeemCode
|
||||
}
|
||||
|
||||
func newStubAdminService() *stubAdminService {
|
||||
now := time.Now().UTC()
|
||||
user := service.User{
|
||||
ID: 1,
|
||||
Email: "user@example.com",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
apiKey := service.APIKey{
|
||||
ID: 10,
|
||||
UserID: user.ID,
|
||||
Key: "sk-test",
|
||||
Name: "test",
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
group := service.Group{
|
||||
ID: 2,
|
||||
Name: "group",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
account := service.Account{
|
||||
ID: 3,
|
||||
Name: "account",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
proxy := service.Proxy{
|
||||
ID: 4,
|
||||
Name: "proxy",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
redeem := service.RedeemCode{
|
||||
ID: 5,
|
||||
Code: "R-TEST",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 10,
|
||||
Status: service.StatusUnused,
|
||||
CreatedAt: now,
|
||||
}
|
||||
return &stubAdminService{
|
||||
users: []service.User{user},
|
||||
apiKeys: []service.APIKey{apiKey},
|
||||
groups: []service.Group{group},
|
||||
accounts: []service.Account{account},
|
||||
proxies: []service.Proxy{proxy},
|
||||
proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}},
|
||||
redeems: []service.RedeemCode{redeem},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) {
|
||||
return s.users, int64(len(s.users)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
|
||||
for i := range s.users {
|
||||
if s.users[i].ID == id {
|
||||
return &s.users[i], nil
|
||||
}
|
||||
}
|
||||
user := service.User{ID: id, Email: "user@example.com", Status: service.StatusActive}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateUser(ctx context.Context, input *service.CreateUserInput) (*service.User, error) {
|
||||
user := service.User{ID: 100, Email: input.Email, Status: service.StatusActive}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateUser(ctx context.Context, id int64, input *service.UpdateUserInput) (*service.User, error) {
|
||||
user := service.User{ID: id, Email: "updated@example.com", Status: service.StatusActive}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteUser(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*service.User, error) {
|
||||
user := service.User{ID: userID, Balance: balance, Status: service.StatusActive}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) {
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
|
||||
return map[string]any{"user_id": userID}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) {
|
||||
return s.groups, int64(len(s.groups)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAllGroups(ctx context.Context) ([]service.Group, error) {
|
||||
return s.groups, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
return s.groups, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Group, error) {
|
||||
group := service.Group{ID: id, Name: "group", Status: service.StatusActive}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) {
|
||||
group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateGroup(ctx context.Context, id int64, input *service.UpdateGroupInput) (*service.Group, error) {
|
||||
group := service.Group{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteGroup(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]service.APIKey, int64, error) {
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
|
||||
out := make([]*service.Account, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
|
||||
out = append(out, &account)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) {
|
||||
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) RefreshAccountCredentials(ctx context.Context, id int64) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
|
||||
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
|
||||
return s.proxies, int64(len(s.proxies)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
|
||||
return s.proxyCounts, int64(len(s.proxyCounts)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAllProxies(ctx context.Context) ([]service.Proxy, error) {
|
||||
return s.proxies, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
|
||||
return s.proxyCounts, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) {
|
||||
proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) {
|
||||
proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) {
|
||||
proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteProxy(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchDeleteProxies(ctx context.Context, ids []int64) (*service.ProxyBatchDeleteResult, error) {
|
||||
return &service.ProxyBatchDeleteResult{DeletedIDs: ids}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetProxyAccounts(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
|
||||
return []service.ProxyAccountSummary{{ID: 1, Name: "account"}}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) {
|
||||
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
|
||||
return s.redeems, int64(len(s.redeems)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
|
||||
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUnused}
|
||||
return &code, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GenerateRedeemCodes(ctx context.Context, input *service.GenerateRedeemCodesInput) ([]service.RedeemCode, error) {
|
||||
return s.redeems, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteRedeemCode(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
|
||||
return int64(len(ids)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
|
||||
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUsed}
|
||||
return &code, nil
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
@@ -186,13 +186,17 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
||||
|
||||
// GetUsageTrend handles getting usage trend data
|
||||
// GET /api/v1/admin/dashboard/trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
|
||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID int64
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
var model string
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = id
|
||||
@@ -203,8 +207,35 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
apiKeyID = id
|
||||
}
|
||||
}
|
||||
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
||||
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
|
||||
accountID = id
|
||||
}
|
||||
}
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
|
||||
groupID = id
|
||||
}
|
||||
}
|
||||
if modelStr := c.Query("model"); modelStr != "" {
|
||||
model = modelStr
|
||||
}
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
||||
stream = &streamVal
|
||||
}
|
||||
}
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
|
||||
bt := int8(v)
|
||||
billingType = &bt
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
@@ -220,12 +251,15 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
|
||||
// GetModelStats handles getting model usage statistics
|
||||
// GET /api/v1/admin/dashboard/models
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
|
||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID int64
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = id
|
||||
@@ -236,8 +270,32 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
apiKeyID = id
|
||||
}
|
||||
}
|
||||
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
||||
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
|
||||
accountID = id
|
||||
}
|
||||
}
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
|
||||
groupID = id
|
||||
}
|
||||
}
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
||||
stream = &streamVal
|
||||
}
|
||||
}
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
|
||||
bt := int8(v)
|
||||
billingType = &bt
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
|
||||
@@ -40,6 +40,9 @@ type CreateGroupRequest struct {
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
@@ -60,6 +63,9 @@ type UpdateGroupRequest struct {
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||
}
|
||||
|
||||
// List handles listing all groups with pagination
|
||||
@@ -88,9 +94,9 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
outGroups := make([]dto.AdminGroup, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
|
||||
}
|
||||
response.Paginated(c, outGroups, total, page, pageSize)
|
||||
}
|
||||
@@ -114,9 +120,9 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
outGroups := make([]dto.AdminGroup, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
|
||||
}
|
||||
response.Success(c, outGroups)
|
||||
}
|
||||
@@ -136,7 +142,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// Create handles creating a new group
|
||||
@@ -149,27 +155,29 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// Update handles updating a group
|
||||
@@ -188,28 +196,30 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// Delete handles deleting a group
|
||||
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
@@ -18,8 +20,6 @@ var validOpsAlertMetricTypes = []string{
|
||||
"success_rate",
|
||||
"error_rate",
|
||||
"upstream_error_rate",
|
||||
"p95_latency_ms",
|
||||
"p99_latency_ms",
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent",
|
||||
"concurrency_queue_depth",
|
||||
@@ -372,8 +372,135 @@ func (h *OpsHandler) DeleteAlertRule(c *gin.Context) {
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
// GetAlertEvent returns a single ops alert event.
|
||||
// GET /api/v1/admin/ops/alert-events/:id
|
||||
func (h *OpsHandler) GetAlertEvent(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid event ID")
|
||||
return
|
||||
}
|
||||
|
||||
ev, err := h.opsService.GetAlertEventByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, ev)
|
||||
}
|
||||
|
||||
// UpdateAlertEventStatus updates an ops alert event status.
|
||||
// PUT /api/v1/admin/ops/alert-events/:id/status
|
||||
func (h *OpsHandler) UpdateAlertEventStatus(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid event ID")
|
||||
return
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
response.BadRequest(c, "Invalid request body")
|
||||
return
|
||||
}
|
||||
payload.Status = strings.TrimSpace(payload.Status)
|
||||
if payload.Status == "" {
|
||||
response.BadRequest(c, "Invalid status")
|
||||
return
|
||||
}
|
||||
if payload.Status != service.OpsAlertStatusResolved && payload.Status != service.OpsAlertStatusManualResolved {
|
||||
response.BadRequest(c, "Invalid status")
|
||||
return
|
||||
}
|
||||
|
||||
var resolvedAt *time.Time
|
||||
if payload.Status == service.OpsAlertStatusResolved || payload.Status == service.OpsAlertStatusManualResolved {
|
||||
now := time.Now().UTC()
|
||||
resolvedAt = &now
|
||||
}
|
||||
if err := h.opsService.UpdateAlertEventStatus(c.Request.Context(), id, payload.Status, resolvedAt); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"updated": true})
|
||||
}
|
||||
|
||||
// ListAlertEvents lists recent ops alert events.
|
||||
// GET /api/v1/admin/ops/alert-events
|
||||
// CreateAlertSilence creates a scoped silence for ops alerts.
|
||||
// POST /api/v1/admin/ops/alert-silences
|
||||
func (h *OpsHandler) CreateAlertSilence(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
RuleID int64 `json:"rule_id"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Region *string `json:"region"`
|
||||
Until string `json:"until"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
response.BadRequest(c, "Invalid request body")
|
||||
return
|
||||
}
|
||||
until, err := time.Parse(time.RFC3339, strings.TrimSpace(payload.Until))
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid until")
|
||||
return
|
||||
}
|
||||
|
||||
createdBy := (*int64)(nil)
|
||||
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
|
||||
uid := subject.UserID
|
||||
createdBy = &uid
|
||||
}
|
||||
|
||||
silence := &service.OpsAlertSilence{
|
||||
RuleID: payload.RuleID,
|
||||
Platform: strings.TrimSpace(payload.Platform),
|
||||
GroupID: payload.GroupID,
|
||||
Region: payload.Region,
|
||||
Until: until,
|
||||
Reason: strings.TrimSpace(payload.Reason),
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
|
||||
created, err := h.opsService.CreateAlertSilence(c.Request.Context(), silence)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, created)
|
||||
}
|
||||
|
||||
func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
@@ -384,7 +511,7 @@ func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
limit := 100
|
||||
limit := 20
|
||||
if raw := strings.TrimSpace(c.Query("limit")); raw != "" {
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err != nil || n <= 0 {
|
||||
@@ -400,6 +527,49 @@ func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
|
||||
Severity: strings.TrimSpace(c.Query("severity")),
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(c.Query("email_sent")); v != "" {
|
||||
vv := strings.ToLower(v)
|
||||
switch vv {
|
||||
case "true", "1":
|
||||
b := true
|
||||
filter.EmailSent = &b
|
||||
case "false", "0":
|
||||
b := false
|
||||
filter.EmailSent = &b
|
||||
default:
|
||||
response.BadRequest(c, "Invalid email_sent")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Cursor pagination: both params must be provided together.
|
||||
rawTS := strings.TrimSpace(c.Query("before_fired_at"))
|
||||
rawID := strings.TrimSpace(c.Query("before_id"))
|
||||
if (rawTS == "") != (rawID == "") {
|
||||
response.BadRequest(c, "before_fired_at and before_id must be provided together")
|
||||
return
|
||||
}
|
||||
if rawTS != "" {
|
||||
ts, err := time.Parse(time.RFC3339Nano, rawTS)
|
||||
if err != nil {
|
||||
if t2, err2 := time.Parse(time.RFC3339, rawTS); err2 == nil {
|
||||
ts = t2
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid before_fired_at")
|
||||
return
|
||||
}
|
||||
}
|
||||
filter.BeforeFiredAt = &ts
|
||||
}
|
||||
if rawID != "" {
|
||||
id, err := strconv.ParseInt(rawID, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid before_id")
|
||||
return
|
||||
}
|
||||
filter.BeforeID = &id
|
||||
}
|
||||
|
||||
// Optional global filter support (platform/group/time range).
|
||||
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
||||
filter.Platform = platform
|
||||
|
||||
@@ -19,6 +19,57 @@ type OpsHandler struct {
|
||||
opsService *service.OpsService
|
||||
}
|
||||
|
||||
// GetErrorLogByID returns ops error log detail.
|
||||
// GET /api/v1/admin/ops/errors/:id
|
||||
func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid error id")
|
||||
return
|
||||
}
|
||||
|
||||
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, detail)
|
||||
}
|
||||
|
||||
const (
|
||||
opsListViewErrors = "errors"
|
||||
opsListViewExcluded = "excluded"
|
||||
opsListViewAll = "all"
|
||||
)
|
||||
|
||||
func parseOpsViewParam(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
v := strings.ToLower(strings.TrimSpace(c.Query("view")))
|
||||
switch v {
|
||||
case "", opsListViewErrors:
|
||||
return opsListViewErrors
|
||||
case opsListViewExcluded:
|
||||
return opsListViewExcluded
|
||||
case opsListViewAll:
|
||||
return opsListViewAll
|
||||
default:
|
||||
return opsListViewErrors
|
||||
}
|
||||
}
|
||||
|
||||
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
|
||||
return &OpsHandler{opsService: opsService}
|
||||
}
|
||||
@@ -47,16 +98,26 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsErrorLogFilter{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
|
||||
|
||||
if !startTime.IsZero() {
|
||||
filter.StartTime = &startTime
|
||||
}
|
||||
if !endTime.IsZero() {
|
||||
filter.EndTime = &endTime
|
||||
}
|
||||
filter.View = parseOpsViewParam(c)
|
||||
filter.Phase = strings.TrimSpace(c.Query("phase"))
|
||||
filter.Owner = strings.TrimSpace(c.Query("error_owner"))
|
||||
filter.Source = strings.TrimSpace(c.Query("error_source"))
|
||||
filter.Query = strings.TrimSpace(c.Query("q"))
|
||||
filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
|
||||
|
||||
// Force request errors: client-visible status >= 400.
|
||||
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
|
||||
if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
|
||||
filter.Phase = ""
|
||||
}
|
||||
|
||||
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
||||
filter.Platform = platform
|
||||
@@ -77,11 +138,19 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
|
||||
}
|
||||
filter.AccountID = &id
|
||||
}
|
||||
if phase := strings.TrimSpace(c.Query("phase")); phase != "" {
|
||||
filter.Phase = phase
|
||||
}
|
||||
if q := strings.TrimSpace(c.Query("q")); q != "" {
|
||||
filter.Query = q
|
||||
|
||||
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
|
||||
switch strings.ToLower(v) {
|
||||
case "1", "true", "yes":
|
||||
b := true
|
||||
filter.Resolved = &b
|
||||
case "0", "false", "no":
|
||||
b := false
|
||||
filter.Resolved = &b
|
||||
default:
|
||||
response.BadRequest(c, "Invalid resolved")
|
||||
return
|
||||
}
|
||||
}
|
||||
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
|
||||
parts := strings.Split(statusCodesStr, ",")
|
||||
@@ -106,13 +175,120 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
|
||||
}
|
||||
|
||||
// GetErrorLogByID returns a single error log detail.
|
||||
// GET /api/v1/admin/ops/errors/:id
|
||||
func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
|
||||
// ListRequestErrors lists client-visible request errors.
|
||||
// GET /api/v1/admin/ops/request-errors
|
||||
func (h *OpsHandler) ListRequestErrors(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
if pageSize > 500 {
|
||||
pageSize = 500
|
||||
}
|
||||
startTime, endTime, err := parseOpsTimeRange(c, "1h")
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
|
||||
if !startTime.IsZero() {
|
||||
filter.StartTime = &startTime
|
||||
}
|
||||
if !endTime.IsZero() {
|
||||
filter.EndTime = &endTime
|
||||
}
|
||||
filter.View = parseOpsViewParam(c)
|
||||
filter.Phase = strings.TrimSpace(c.Query("phase"))
|
||||
filter.Owner = strings.TrimSpace(c.Query("error_owner"))
|
||||
filter.Source = strings.TrimSpace(c.Query("error_source"))
|
||||
filter.Query = strings.TrimSpace(c.Query("q"))
|
||||
filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
|
||||
|
||||
// Force request errors: client-visible status >= 400.
|
||||
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
|
||||
if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
|
||||
filter.Phase = ""
|
||||
}
|
||||
|
||||
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
||||
filter.Platform = platform
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||
id, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid group_id")
|
||||
return
|
||||
}
|
||||
filter.GroupID = &id
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
|
||||
id, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid account_id")
|
||||
return
|
||||
}
|
||||
filter.AccountID = &id
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
|
||||
switch strings.ToLower(v) {
|
||||
case "1", "true", "yes":
|
||||
b := true
|
||||
filter.Resolved = &b
|
||||
case "0", "false", "no":
|
||||
b := false
|
||||
filter.Resolved = &b
|
||||
default:
|
||||
response.BadRequest(c, "Invalid resolved")
|
||||
return
|
||||
}
|
||||
}
|
||||
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
|
||||
parts := strings.Split(statusCodesStr, ",")
|
||||
out := make([]int, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
p := strings.TrimSpace(part)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(p)
|
||||
if err != nil || n < 0 {
|
||||
response.BadRequest(c, "Invalid status_codes")
|
||||
return
|
||||
}
|
||||
out = append(out, n)
|
||||
}
|
||||
filter.StatusCodes = out
|
||||
}
|
||||
|
||||
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
|
||||
}
|
||||
|
||||
// GetRequestError returns request error detail.
|
||||
// GET /api/v1/admin/ops/request-errors/:id
|
||||
func (h *OpsHandler) GetRequestError(c *gin.Context) {
|
||||
// same storage; just proxy to existing detail
|
||||
h.GetErrorLogByID(c)
|
||||
}
|
||||
|
||||
// ListRequestErrorUpstreamErrors lists upstream error logs correlated to a request error.
|
||||
// GET /api/v1/admin/ops/request-errors/:id/upstream-errors
|
||||
func (h *OpsHandler) ListRequestErrorUpstreamErrors(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
@@ -129,15 +305,306 @@ func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Load request error to get correlation keys.
|
||||
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, detail)
|
||||
// Correlate by request_id/client_request_id.
|
||||
requestID := strings.TrimSpace(detail.RequestID)
|
||||
clientRequestID := strings.TrimSpace(detail.ClientRequestID)
|
||||
if requestID == "" && clientRequestID == "" {
|
||||
response.Paginated(c, []*service.OpsErrorLog{}, 0, 1, 10)
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
if pageSize > 500 {
|
||||
pageSize = 500
|
||||
}
|
||||
|
||||
// Keep correlation window wide enough so linked upstream errors
|
||||
// are discoverable even when UI defaults to 1h elsewhere.
|
||||
startTime, endTime, err := parseOpsTimeRange(c, "30d")
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
|
||||
if !startTime.IsZero() {
|
||||
filter.StartTime = &startTime
|
||||
}
|
||||
if !endTime.IsZero() {
|
||||
filter.EndTime = &endTime
|
||||
}
|
||||
filter.View = "all"
|
||||
filter.Phase = "upstream"
|
||||
filter.Owner = "provider"
|
||||
filter.Source = strings.TrimSpace(c.Query("error_source"))
|
||||
filter.Query = strings.TrimSpace(c.Query("q"))
|
||||
|
||||
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
||||
filter.Platform = platform
|
||||
}
|
||||
|
||||
// Prefer exact match on request_id; if missing, fall back to client_request_id.
|
||||
if requestID != "" {
|
||||
filter.RequestID = requestID
|
||||
} else {
|
||||
filter.ClientRequestID = clientRequestID
|
||||
}
|
||||
|
||||
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// If client asks for details, expand each upstream error log to include upstream response fields.
|
||||
includeDetail := strings.TrimSpace(c.Query("include_detail"))
|
||||
if includeDetail == "1" || strings.EqualFold(includeDetail, "true") || strings.EqualFold(includeDetail, "yes") {
|
||||
details := make([]*service.OpsErrorLogDetail, 0, len(result.Errors))
|
||||
for _, item := range result.Errors {
|
||||
if item == nil {
|
||||
continue
|
||||
}
|
||||
d, err := h.opsService.GetErrorLogByID(c.Request.Context(), item.ID)
|
||||
if err != nil || d == nil {
|
||||
continue
|
||||
}
|
||||
details = append(details, d)
|
||||
}
|
||||
response.Paginated(c, details, int64(result.Total), result.Page, result.PageSize)
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
|
||||
}
|
||||
|
||||
// RetryRequestErrorClient retries the client request based on stored request body.
|
||||
// POST /api/v1/admin/ops/request-errors/:id/retry-client
|
||||
func (h *OpsHandler) RetryRequestErrorClient(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid error id")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeClient, nil)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// RetryRequestErrorUpstreamEvent retries a specific upstream attempt using captured upstream_request_body.
|
||||
// POST /api/v1/admin/ops/request-errors/:id/upstream-errors/:idx/retry
|
||||
func (h *OpsHandler) RetryRequestErrorUpstreamEvent(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid error id")
|
||||
return
|
||||
}
|
||||
|
||||
idxStr := strings.TrimSpace(c.Param("idx"))
|
||||
idx, err := strconv.Atoi(idxStr)
|
||||
if err != nil || idx < 0 {
|
||||
response.BadRequest(c, "Invalid upstream idx")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.opsService.RetryUpstreamEvent(c.Request.Context(), subject.UserID, id, idx)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ResolveRequestError toggles resolved status.
|
||||
// PUT /api/v1/admin/ops/request-errors/:id/resolve
|
||||
func (h *OpsHandler) ResolveRequestError(c *gin.Context) {
|
||||
h.UpdateErrorResolution(c)
|
||||
}
|
||||
|
||||
// ListUpstreamErrors lists independent upstream errors.
|
||||
// GET /api/v1/admin/ops/upstream-errors
|
||||
func (h *OpsHandler) ListUpstreamErrors(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
if pageSize > 500 {
|
||||
pageSize = 500
|
||||
}
|
||||
startTime, endTime, err := parseOpsTimeRange(c, "1h")
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
|
||||
if !startTime.IsZero() {
|
||||
filter.StartTime = &startTime
|
||||
}
|
||||
if !endTime.IsZero() {
|
||||
filter.EndTime = &endTime
|
||||
}
|
||||
|
||||
filter.View = parseOpsViewParam(c)
|
||||
filter.Phase = "upstream"
|
||||
filter.Owner = "provider"
|
||||
filter.Source = strings.TrimSpace(c.Query("error_source"))
|
||||
filter.Query = strings.TrimSpace(c.Query("q"))
|
||||
|
||||
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
||||
filter.Platform = platform
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||
id, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid group_id")
|
||||
return
|
||||
}
|
||||
filter.GroupID = &id
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
|
||||
id, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid account_id")
|
||||
return
|
||||
}
|
||||
filter.AccountID = &id
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
|
||||
switch strings.ToLower(v) {
|
||||
case "1", "true", "yes":
|
||||
b := true
|
||||
filter.Resolved = &b
|
||||
case "0", "false", "no":
|
||||
b := false
|
||||
filter.Resolved = &b
|
||||
default:
|
||||
response.BadRequest(c, "Invalid resolved")
|
||||
return
|
||||
}
|
||||
}
|
||||
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
|
||||
parts := strings.Split(statusCodesStr, ",")
|
||||
out := make([]int, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
p := strings.TrimSpace(part)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(p)
|
||||
if err != nil || n < 0 {
|
||||
response.BadRequest(c, "Invalid status_codes")
|
||||
return
|
||||
}
|
||||
out = append(out, n)
|
||||
}
|
||||
filter.StatusCodes = out
|
||||
}
|
||||
|
||||
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
|
||||
}
|
||||
|
||||
// GetUpstreamError returns upstream error detail.
|
||||
// GET /api/v1/admin/ops/upstream-errors/:id
|
||||
func (h *OpsHandler) GetUpstreamError(c *gin.Context) {
|
||||
h.GetErrorLogByID(c)
|
||||
}
|
||||
|
||||
// RetryUpstreamError retries upstream error using the original account_id.
|
||||
// POST /api/v1/admin/ops/upstream-errors/:id/retry
|
||||
func (h *OpsHandler) RetryUpstreamError(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid error id")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeUpstream, nil)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ResolveUpstreamError toggles resolved status.
|
||||
// PUT /api/v1/admin/ops/upstream-errors/:id/resolve
|
||||
func (h *OpsHandler) ResolveUpstreamError(c *gin.Context) {
|
||||
h.UpdateErrorResolution(c)
|
||||
}
|
||||
|
||||
// ==================== Existing endpoints ====================
|
||||
|
||||
// ListRequestDetails returns a request-level list (success + error) for drill-down.
|
||||
// GET /api/v1/admin/ops/requests
|
||||
func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
|
||||
@@ -242,6 +709,11 @@ func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
|
||||
type opsRetryRequest struct {
|
||||
Mode string `json:"mode"`
|
||||
PinnedAccountID *int64 `json:"pinned_account_id"`
|
||||
Force bool `json:"force"`
|
||||
}
|
||||
|
||||
type opsResolveRequest struct {
|
||||
Resolved bool `json:"resolved"`
|
||||
}
|
||||
|
||||
// RetryErrorRequest retries a failed request using stored request_body.
|
||||
@@ -278,6 +750,16 @@ func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
|
||||
req.Mode = service.OpsRetryModeClient
|
||||
}
|
||||
|
||||
// Force flag is currently a UI-level acknowledgement. Server may still enforce safety constraints.
|
||||
_ = req.Force
|
||||
|
||||
// Legacy endpoint safety: only allow retrying the client request here.
|
||||
// Upstream retries must go through the split endpoints.
|
||||
if strings.EqualFold(strings.TrimSpace(req.Mode), service.OpsRetryModeUpstream) {
|
||||
response.BadRequest(c, "upstream retry is not supported on this endpoint")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -287,6 +769,81 @@ func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ListRetryAttempts lists retry attempts for an error log.
|
||||
// GET /api/v1/admin/ops/errors/:id/retries
|
||||
func (h *OpsHandler) ListRetryAttempts(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid error id")
|
||||
return
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if v := strings.TrimSpace(c.Query("limit")); v != "" {
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil || n <= 0 {
|
||||
response.BadRequest(c, "Invalid limit")
|
||||
return
|
||||
}
|
||||
limit = n
|
||||
}
|
||||
|
||||
items, err := h.opsService.ListRetryAttemptsByErrorID(c.Request.Context(), id, limit)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, items)
|
||||
}
|
||||
|
||||
// UpdateErrorResolution allows manual resolve/unresolve.
|
||||
// PUT /api/v1/admin/ops/errors/:id/resolve
|
||||
func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid error id")
|
||||
return
|
||||
}
|
||||
|
||||
var req opsResolveRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
uid := subject.UserID
|
||||
if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid, nil); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
func parseOpsTimeRange(c *gin.Context, defaultRange string) (time.Time, time.Time, error) {
|
||||
startStr := strings.TrimSpace(c.Query("start_time"))
|
||||
endStr := strings.TrimSpace(c.Query("end_time"))
|
||||
@@ -358,6 +915,10 @@ func parseOpsDuration(v string) (time.Duration, bool) {
|
||||
return 6 * time.Hour, true
|
||||
case "24h":
|
||||
return 24 * time.Hour, true
|
||||
case "7d":
|
||||
return 7 * 24 * time.Hour, true
|
||||
case "30d":
|
||||
return 30 * 24 * time.Hour, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
|
||||
@@ -196,6 +196,28 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "Proxy deleted successfully"})
|
||||
}
|
||||
|
||||
// BatchDelete handles batch deleting proxies
|
||||
// POST /api/v1/admin/proxies/batch-delete
|
||||
func (h *ProxyHandler) BatchDelete(c *gin.Context) {
|
||||
type BatchDeleteRequest struct {
|
||||
IDs []int64 `json:"ids" binding:"required,min=1"`
|
||||
}
|
||||
|
||||
var req BatchDeleteRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.adminService.BatchDeleteProxies(c.Request.Context(), req.IDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Test handles testing proxy connectivity
|
||||
// POST /api/v1/admin/proxies/:id/test
|
||||
func (h *ProxyHandler) Test(c *gin.Context) {
|
||||
@@ -243,19 +265,17 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
|
||||
accounts, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.Account, 0, len(accounts))
|
||||
out := make([]dto.ProxyAccountSummary, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
out = append(out, *dto.AccountFromService(&accounts[i]))
|
||||
out = append(out, *dto.ProxyAccountSummaryFromService(&accounts[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||
|
||||
@@ -54,9 +54,9 @@ func (h *RedeemHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
@@ -76,7 +76,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
|
||||
}
|
||||
|
||||
// Generate handles generating new redeem codes
|
||||
@@ -100,9 +100,9 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
|
||||
}
|
||||
|
||||
// GetStats handles getting redeem code statistics
|
||||
|
||||
@@ -47,6 +47,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
@@ -68,6 +72,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
@@ -87,8 +92,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
// 邮件服务设置
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
@@ -111,13 +119,14 @@ type UpdateSettingsRequest struct {
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -194,6 +203,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// TOTP 双因素认证参数验证
|
||||
// 只有手动配置了加密密钥才允许启用 TOTP 功能
|
||||
if req.TotpEnabled && !previousSettings.TotpEnabled {
|
||||
// 尝试启用 TOTP,检查加密密钥是否已手动配置
|
||||
if !h.settingService.IsTotpEncryptionKeyConfigured() {
|
||||
response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// LinuxDo Connect 参数验证
|
||||
if req.LinuxDoConnectEnabled {
|
||||
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
|
||||
@@ -238,6 +257,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
@@ -259,6 +281,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
@@ -311,6 +334,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
TotpEnabled: updatedSettings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
@@ -332,6 +359,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
HomeContent: updatedSettings.HomeContent,
|
||||
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
@@ -376,6 +404,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
||||
changed = append(changed, "email_verify_enabled")
|
||||
}
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
if before.TotpEnabled != after.TotpEnabled {
|
||||
changed = append(changed, "totp_enabled")
|
||||
}
|
||||
if before.SMTPHost != after.SMTPHost {
|
||||
changed = append(changed, "smtp_host")
|
||||
}
|
||||
@@ -439,6 +473,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.HomeContent != after.HomeContent {
|
||||
changed = append(changed, "home_content")
|
||||
}
|
||||
if before.HideCcsImportButton != after.HideCcsImportButton {
|
||||
changed = append(changed, "hide_ccs_import_button")
|
||||
}
|
||||
if before.DefaultConcurrency != after.DefaultConcurrency {
|
||||
changed = append(changed, "default_concurrency")
|
||||
}
|
||||
|
||||
@@ -53,9 +53,9 @@ type BulkAssignSubscriptionRequest struct {
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// ExtendSubscriptionRequest represents extend subscription request
|
||||
type ExtendSubscriptionRequest struct {
|
||||
Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years
|
||||
// AdjustSubscriptionRequest represents adjust subscription request (extend or shorten)
|
||||
type AdjustSubscriptionRequest struct {
|
||||
Days int `json:"days" binding:"required,min=-36500,max=36500"` // negative to shorten, positive to extend
|
||||
}
|
||||
|
||||
// List handles listing all subscriptions with pagination and filters
|
||||
@@ -77,15 +77,19 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
}
|
||||
status := c.Query("status")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
|
||||
// Parse sorting parameters
|
||||
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
@@ -105,7 +109,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
}
|
||||
|
||||
// GetProgress handles getting subscription usage progress
|
||||
@@ -150,7 +154,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
}
|
||||
|
||||
// BulkAssign handles bulk assigning subscriptions to multiple users
|
||||
@@ -180,7 +184,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
|
||||
response.Success(c, dto.BulkAssignResultFromService(result))
|
||||
}
|
||||
|
||||
// Extend handles extending a subscription
|
||||
// Extend handles adjusting a subscription (extend or shorten)
|
||||
// POST /api/v1/admin/subscriptions/:id/extend
|
||||
func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
@@ -189,7 +193,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var req ExtendSubscriptionRequest
|
||||
var req AdjustSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
@@ -201,7 +205,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
@@ -239,9 +243,9 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
@@ -261,9 +265,9 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
377
backend/internal/handler/admin/usage_cleanup_handler_test.go
Normal file
377
backend/internal/handler/admin/usage_cleanup_handler_test.go
Normal file
@@ -0,0 +1,377 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type cleanupRepoStub struct {
|
||||
mu sync.Mutex
|
||||
created []*service.UsageCleanupTask
|
||||
listTasks []service.UsageCleanupTask
|
||||
listResult *pagination.PaginationResult
|
||||
listErr error
|
||||
statusByID map[int64]string
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if task.ID == 0 {
|
||||
task.ID = int64(len(s.created) + 1)
|
||||
}
|
||||
if task.CreatedAt.IsZero() {
|
||||
task.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
task.UpdatedAt = task.CreatedAt
|
||||
clone := *task
|
||||
s.created = append(s.created, &clone)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.listTasks, s.listResult, s.listErr
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.statusByID == nil {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
status, ok := s.statusByID[taskID]
|
||||
if !ok {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
status := s.statusByID[taskID]
|
||||
if status != service.UsageCleanupStatusPending && status != service.UsageCleanupStatusRunning {
|
||||
return false, nil
|
||||
}
|
||||
s.statusByID[taskID] = service.UsageCleanupStatusCanceled
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var _ service.UsageCleanupRepository = (*cleanupRepoStub)(nil)
|
||||
|
||||
func setupCleanupRouter(cleanupService *service.UsageCleanupService, userID int64) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
if userID > 0 {
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
|
||||
handler := NewUsageHandler(nil, nil, nil, cleanupService)
|
||||
router.POST("/api/v1/admin/usage/cleanup-tasks", handler.CreateCleanupTask)
|
||||
router.GET("/api/v1/admin/usage/cleanup-tasks", handler.ListCleanupTasks)
|
||||
router.POST("/api/v1/admin/usage/cleanup-tasks/:id/cancel", handler.CancelCleanupTask)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskUnauthorized(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 0)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskUnavailable(t *testing.T) {
|
||||
router := setupCleanupRouter(nil, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskBindError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString("{bad-json"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskMissingRange(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"timezone": "UTC",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskInvalidDate(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-13-01",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-02-40",
|
||||
"timezone": "UTC",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 99)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": " 2024-01-01 ",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
"model": "gpt-4",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.created, 1)
|
||||
created := repo.created[0]
|
||||
require.Equal(t, int64(99), created.CreatedBy)
|
||||
require.NotNil(t, created.Filters.Model)
|
||||
require.Equal(t, "gpt-4", *created.Filters.Model)
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC).Add(24*time.Hour - time.Nanosecond)
|
||||
require.True(t, created.Filters.StartTime.Equal(start))
|
||||
require.True(t, created.Filters.EndTime.Equal(end))
|
||||
}
|
||||
|
||||
func TestUsageHandlerListCleanupTasksUnavailable(t *testing.T) {
|
||||
router := setupCleanupRouter(nil, 0)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerListCleanupTasksSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
repo.listTasks = []service.UsageCleanupTask{
|
||||
{
|
||||
ID: 7,
|
||||
Status: service.UsageCleanupStatusSucceeded,
|
||||
CreatedBy: 4,
|
||||
},
|
||||
}
|
||||
repo.listResult = &pagination.PaginationResult{Total: 1, Page: 1, PageSize: 20, Pages: 1}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
Items []dto.UsageCleanupTask `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Len(t, resp.Data.Items, 1)
|
||||
require.Equal(t, int64(7), resp.Data.Items[0].ID)
|
||||
require.Equal(t, int64(1), resp.Data.Total)
|
||||
require.Equal(t, 1, resp.Data.Page)
|
||||
}
|
||||
|
||||
func TestUsageHandlerListCleanupTasksError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{listErr: errors.New("boom")}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCancelCleanupTaskUnauthorized(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 0)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/1/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCancelCleanupTaskNotFound(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/999/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCancelCleanupTaskConflict(t *testing.T) {
|
||||
repo := &cleanupRepoStub{statusByID: map[int64]string{2: service.UsageCleanupStatusSucceeded}}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/2/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusConflict, rec.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCancelCleanupTaskSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{statusByID: map[int64]string{3: service.UsageCleanupStatusPending}}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/3/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
@@ -1,7 +1,10 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -9,6 +12,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -16,9 +20,10 @@ import (
|
||||
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.APIKeyService
|
||||
adminService service.AdminService
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.APIKeyService
|
||||
adminService service.AdminService
|
||||
cleanupService *service.UsageCleanupService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
@@ -26,14 +31,30 @@ func NewUsageHandler(
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
adminService service.AdminService,
|
||||
cleanupService *service.UsageCleanupService,
|
||||
) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
adminService: adminService,
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
adminService: adminService,
|
||||
cleanupService: cleanupService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateUsageCleanupTaskRequest represents cleanup task creation request
|
||||
type CreateUsageCleanupTaskRequest struct {
|
||||
StartDate string `json:"start_date"`
|
||||
EndDate string `json:"end_date"`
|
||||
UserID *int64 `json:"user_id"`
|
||||
APIKeyID *int64 `json:"api_key_id"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Model *string `json:"model"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
Timezone string `json:"timezone"`
|
||||
}
|
||||
|
||||
// List handles listing all usage records with filters
|
||||
// GET /api/v1/admin/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
@@ -142,7 +163,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UsageLog, 0, len(records))
|
||||
out := make([]dto.AdminUsageLog, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
|
||||
}
|
||||
@@ -344,3 +365,162 @@ func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ListCleanupTasks handles listing usage cleanup tasks
|
||||
// GET /api/v1/admin/usage/cleanup-tasks
|
||||
func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
|
||||
if h.cleanupService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
|
||||
return
|
||||
}
|
||||
operator := int64(0)
|
||||
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
|
||||
operator = subject.UserID
|
||||
}
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
out := make([]dto.UsageCleanupTask, 0, len(tasks))
|
||||
for i := range tasks {
|
||||
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
|
||||
}
|
||||
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// CreateCleanupTask handles creating a usage cleanup task
|
||||
// POST /api/v1/admin/usage/cleanup-tasks
|
||||
func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
if h.cleanupService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
|
||||
return
|
||||
}
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Unauthorized(c, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateUsageCleanupTaskRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
req.StartDate = strings.TrimSpace(req.StartDate)
|
||||
req.EndDate = strings.TrimSpace(req.EndDate)
|
||||
if req.StartDate == "" || req.EndDate == "" {
|
||||
response.BadRequest(c, "start_date and end_date are required")
|
||||
return
|
||||
}
|
||||
|
||||
startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
|
||||
filters := service.UsageCleanupFilters{
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
UserID: req.UserID,
|
||||
APIKeyID: req.APIKeyID,
|
||||
AccountID: req.AccountID,
|
||||
GroupID: req.GroupID,
|
||||
Model: req.Model,
|
||||
Stream: req.Stream,
|
||||
BillingType: req.BillingType,
|
||||
}
|
||||
|
||||
var userID any
|
||||
if filters.UserID != nil {
|
||||
userID = *filters.UserID
|
||||
}
|
||||
var apiKeyID any
|
||||
if filters.APIKeyID != nil {
|
||||
apiKeyID = *filters.APIKeyID
|
||||
}
|
||||
var accountID any
|
||||
if filters.AccountID != nil {
|
||||
accountID = *filters.AccountID
|
||||
}
|
||||
var groupID any
|
||||
if filters.GroupID != nil {
|
||||
groupID = *filters.GroupID
|
||||
}
|
||||
var model any
|
||||
if filters.Model != nil {
|
||||
model = *filters.Model
|
||||
}
|
||||
var stream any
|
||||
if filters.Stream != nil {
|
||||
stream = *filters.Stream
|
||||
}
|
||||
var billingType any
|
||||
if filters.BillingType != nil {
|
||||
billingType = *filters.BillingType
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
||||
subject.UserID,
|
||||
filters.StartTime.Format(time.RFC3339),
|
||||
filters.EndTime.Format(time.RFC3339),
|
||||
userID,
|
||||
apiKeyID,
|
||||
accountID,
|
||||
groupID,
|
||||
model,
|
||||
stream,
|
||||
billingType,
|
||||
req.Timezone,
|
||||
)
|
||||
|
||||
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
|
||||
response.Success(c, dto.UsageCleanupTaskFromService(task))
|
||||
}
|
||||
|
||||
// CancelCleanupTask handles canceling a usage cleanup task
|
||||
// POST /api/v1/admin/usage/cleanup-tasks/:id/cancel
|
||||
func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
|
||||
if h.cleanupService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
|
||||
return
|
||||
}
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Unauthorized(c, "Unauthorized")
|
||||
return
|
||||
}
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
taskID, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || taskID <= 0 {
|
||||
response.BadRequest(c, "Invalid task id")
|
||||
return
|
||||
}
|
||||
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
|
||||
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
|
||||
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
|
||||
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
|
||||
}
|
||||
|
||||
@@ -84,9 +84,9 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.User, 0, len(users))
|
||||
out := make([]dto.AdminUser, 0, len(users))
|
||||
for i := range users {
|
||||
out = append(out, *dto.UserFromService(&users[i]))
|
||||
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// Create handles creating a new user
|
||||
@@ -155,7 +155,7 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// Update handles updating a user
|
||||
@@ -189,7 +189,7 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// Delete handles deleting a user
|
||||
@@ -231,7 +231,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handles getting user's API keys
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
@@ -18,16 +20,18 @@ type AuthHandler struct {
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
totpService *service.TotpService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
totpService: totpService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,6 +148,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if TOTP 2FA is enabled for this user
|
||||
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
|
||||
// Create a temporary login session for 2FA
|
||||
tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to create 2FA session")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, TotpLoginResponse{
|
||||
Requires2FA: true,
|
||||
TempToken: tempToken,
|
||||
UserEmailMasked: service.MaskEmail(user.Email),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
// TotpLoginResponse represents the response when 2FA is required
|
||||
type TotpLoginResponse struct {
|
||||
Requires2FA bool `json:"requires_2fa"`
|
||||
TempToken string `json:"temp_token,omitempty"`
|
||||
UserEmailMasked string `json:"user_email_masked,omitempty"`
|
||||
}
|
||||
|
||||
// Login2FARequest represents the 2FA login request
|
||||
type Login2FARequest struct {
|
||||
TempToken string `json:"temp_token" binding:"required"`
|
||||
TotpCode string `json:"totp_code" binding:"required,len=6"`
|
||||
}
|
||||
|
||||
// Login2FA completes the login with 2FA verification
|
||||
// POST /api/v1/auth/login/2fa
|
||||
func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
var req Login2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("login_2fa_request",
|
||||
"temp_token_len", len(req.TempToken),
|
||||
"totp_code_len", len(req.TotpCode))
|
||||
|
||||
// Get the login session
|
||||
session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
|
||||
if err != nil || session == nil {
|
||||
tokenPrefix := ""
|
||||
if len(req.TempToken) >= 8 {
|
||||
tokenPrefix = req.TempToken[:8]
|
||||
}
|
||||
slog.Debug("login_2fa_session_invalid",
|
||||
"temp_token_prefix", tokenPrefix,
|
||||
"error", err)
|
||||
response.BadRequest(c, "Invalid or expired 2FA session")
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("login_2fa_session_found",
|
||||
"user_id", session.UserID,
|
||||
"email", session.Email)
|
||||
|
||||
// Verify the TOTP code
|
||||
if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
|
||||
slog.Debug("login_2fa_verify_failed",
|
||||
"user_id", session.UserID,
|
||||
"error", err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate the JWT token
|
||||
token, err := h.authService.GenerateToken(user)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
@@ -195,6 +293,15 @@ type ValidatePromoCodeResponse struct {
|
||||
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
|
||||
// POST /api/v1/auth/validate-promo-code
|
||||
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
// 检查优惠码功能是否启用
|
||||
if h.settingSvc != nil && !h.settingSvc.IsPromoCodeEnabled(c.Request.Context()) {
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: "PROMO_CODE_DISABLED",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req ValidatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
@@ -238,3 +345,85 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
})
|
||||
}
|
||||
|
||||
// ForgotPasswordRequest 忘记密码请求
|
||||
type ForgotPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// ForgotPasswordResponse 忘记密码响应
|
||||
type ForgotPasswordResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ForgotPassword 请求密码重置
|
||||
// POST /api/v1/auth/forgot-password
|
||||
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
var req ForgotPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build frontend base URL from request
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
// Check X-Forwarded-Proto header (common in reverse proxy setups)
|
||||
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
frontendBaseURL := scheme + "://" + c.Request.Host
|
||||
|
||||
// Request password reset (async)
|
||||
// Note: This returns success even if email doesn't exist (to prevent enumeration)
|
||||
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ForgotPasswordResponse{
|
||||
Message: "If your email is registered, you will receive a password reset link shortly.",
|
||||
})
|
||||
}
|
||||
|
||||
// ResetPasswordRequest 重置密码请求
|
||||
type ResetPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Token string `json:"token" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// ResetPasswordResponse 重置密码响应
|
||||
type ResetPasswordResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
// POST /api/v1/auth/reset-password
|
||||
func (h *AuthHandler) ResetPassword(c *gin.Context) {
|
||||
var req ResetPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Reset password
|
||||
if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ResetPasswordResponse{
|
||||
Message: "Your password has been reset successfully. You can now log in with your new password.",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ func UserFromServiceShallow(u *service.User) *User {
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
@@ -48,6 +47,22 @@ func UserFromService(u *service.User) *User {
|
||||
return out
|
||||
}
|
||||
|
||||
// UserFromServiceAdmin converts a service User to DTO for admin users.
|
||||
// It includes notes - user-facing endpoints must not use this.
|
||||
func UserFromServiceAdmin(u *service.User) *AdminUser {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
base := UserFromService(u)
|
||||
if base == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminUser{
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
}
|
||||
}
|
||||
|
||||
func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
@@ -72,7 +87,41 @@ func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return &Group{
|
||||
out := groupFromServiceBase(g)
|
||||
return &out
|
||||
}
|
||||
|
||||
func GroupFromService(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return GroupFromServiceShallow(g)
|
||||
}
|
||||
|
||||
// GroupFromServiceAdmin converts a service Group to DTO for admin users.
|
||||
// It includes internal fields like model_routing and account_count.
|
||||
func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
out := &AdminGroup{
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
for i := range g.AccountGroups {
|
||||
ag := g.AccountGroups[i]
|
||||
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func groupFromServiceBase(g *service.Group) Group {
|
||||
return Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
@@ -91,30 +140,14 @@ func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
}
|
||||
|
||||
func GroupFromService(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
out := GroupFromServiceShallow(g)
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
for i := range g.AccountGroups {
|
||||
ag := g.AccountGroups[i]
|
||||
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
return &Account{
|
||||
out := &Account{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Notes: a.Notes,
|
||||
@@ -125,6 +158,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
Priority: a.Priority,
|
||||
RateMultiplier: a.BillingRateMultiplier(),
|
||||
Status: a.Status,
|
||||
ErrorMessage: a.ErrorMessage,
|
||||
LastUsedAt: a.LastUsedAt,
|
||||
@@ -143,6 +177,34 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
SessionWindowStatus: a.SessionWindowStatus,
|
||||
GroupIDs: a.GroupIDs,
|
||||
}
|
||||
|
||||
// 提取 5h 窗口费用控制和会话数量控制配置(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
if a.IsAnthropicOAuthOrSetupToken() {
|
||||
if limit := a.GetWindowCostLimit(); limit > 0 {
|
||||
out.WindowCostLimit = &limit
|
||||
}
|
||||
if reserve := a.GetWindowCostStickyReserve(); reserve > 0 {
|
||||
out.WindowCostStickyReserve = &reserve
|
||||
}
|
||||
if maxSessions := a.GetMaxSessions(); maxSessions > 0 {
|
||||
out.MaxSessions = &maxSessions
|
||||
}
|
||||
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
|
||||
out.SessionIdleTimeoutMin = &idleTimeout
|
||||
}
|
||||
// TLS指纹伪装开关
|
||||
if a.IsTLSFingerprintEnabled() {
|
||||
enabled := true
|
||||
out.EnableTLSFingerprint = &enabled
|
||||
}
|
||||
// 会话ID伪装开关
|
||||
if a.IsSessionIDMaskingEnabled() {
|
||||
enabled := true
|
||||
out.EnableSessionIDMasking = &enabled
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func AccountFromService(a *service.Account) *Account {
|
||||
@@ -212,8 +274,29 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
|
||||
return nil
|
||||
}
|
||||
return &ProxyWithAccountCount{
|
||||
Proxy: *ProxyFromService(&p.Proxy),
|
||||
AccountCount: p.AccountCount,
|
||||
Proxy: *ProxyFromService(&p.Proxy),
|
||||
AccountCount: p.AccountCount,
|
||||
LatencyMs: p.LatencyMs,
|
||||
LatencyStatus: p.LatencyStatus,
|
||||
LatencyMessage: p.LatencyMessage,
|
||||
IPAddress: p.IPAddress,
|
||||
Country: p.Country,
|
||||
CountryCode: p.CountryCode,
|
||||
Region: p.Region,
|
||||
City: p.City,
|
||||
}
|
||||
}
|
||||
|
||||
func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
return &ProxyAccountSummary{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Platform: a.Platform,
|
||||
Type: a.Type,
|
||||
Notes: a.Notes,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,7 +304,24 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
||||
if rc == nil {
|
||||
return nil
|
||||
}
|
||||
return &RedeemCode{
|
||||
out := redeemCodeFromServiceBase(rc)
|
||||
return &out
|
||||
}
|
||||
|
||||
// RedeemCodeFromServiceAdmin converts a service RedeemCode to DTO for admin users.
|
||||
// It includes notes - user-facing endpoints must not use this.
|
||||
func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
|
||||
if rc == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminRedeemCode{
|
||||
RedeemCode: redeemCodeFromServiceBase(rc),
|
||||
Notes: rc.Notes,
|
||||
}
|
||||
}
|
||||
|
||||
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
|
||||
return RedeemCode{
|
||||
ID: rc.ID,
|
||||
Code: rc.Code,
|
||||
Type: rc.Type,
|
||||
@@ -229,7 +329,6 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
||||
Status: rc.Status,
|
||||
UsedBy: rc.UsedBy,
|
||||
UsedAt: rc.UsedAt,
|
||||
Notes: rc.Notes,
|
||||
CreatedAt: rc.CreatedAt,
|
||||
GroupID: rc.GroupID,
|
||||
ValidityDays: rc.ValidityDays,
|
||||
@@ -250,14 +349,9 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
||||
}
|
||||
}
|
||||
|
||||
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
|
||||
// The account parameter allows caller to control what Account info is included.
|
||||
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
|
||||
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
result := &UsageLog{
|
||||
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
|
||||
return UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
APIKeyID: l.APIKeyID,
|
||||
@@ -289,30 +383,63 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
APIKey: APIKeyFromService(l.APIKey),
|
||||
Account: account,
|
||||
Group: GroupFromServiceShallow(l.Group),
|
||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||
}
|
||||
// IP 地址仅对管理员可见
|
||||
if includeIPAddress {
|
||||
result.IPAddress = l.IPAddress
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// UsageLogFromService converts a service UsageLog to DTO for regular users.
|
||||
// It excludes Account details and IP address - users should not see these.
|
||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
return usageLogFromServiceBase(l, nil, false)
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
u := usageLogFromServiceUser(l)
|
||||
return &u
|
||||
}
|
||||
|
||||
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
|
||||
// It includes minimal Account info (ID, Name only) and IP address.
|
||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
|
||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
|
||||
return &AdminUsageLog{
|
||||
UsageLog: usageLogFromServiceUser(l),
|
||||
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||
IPAddress: l.IPAddress,
|
||||
Account: AccountSummaryFromService(l.Account),
|
||||
}
|
||||
}
|
||||
|
||||
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageCleanupTask{
|
||||
ID: task.ID,
|
||||
Status: task.Status,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: task.Filters.StartTime,
|
||||
EndTime: task.Filters.EndTime,
|
||||
UserID: task.Filters.UserID,
|
||||
APIKeyID: task.Filters.APIKeyID,
|
||||
AccountID: task.Filters.AccountID,
|
||||
GroupID: task.Filters.GroupID,
|
||||
Model: task.Filters.Model,
|
||||
Stream: task.Filters.Stream,
|
||||
BillingType: task.Filters.BillingType,
|
||||
},
|
||||
CreatedBy: task.CreatedBy,
|
||||
DeletedRows: task.DeletedRows,
|
||||
ErrorMessage: task.ErrorMsg,
|
||||
CanceledBy: task.CanceledBy,
|
||||
CanceledAt: task.CanceledAt,
|
||||
StartedAt: task.StartedAt,
|
||||
FinishedAt: task.FinishedAt,
|
||||
CreatedAt: task.CreatedAt,
|
||||
UpdatedAt: task.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func SettingFromService(s *service.Setting) *Setting {
|
||||
@@ -331,7 +458,27 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserSubscription{
|
||||
out := userSubscriptionFromServiceBase(sub)
|
||||
return &out
|
||||
}
|
||||
|
||||
// UserSubscriptionFromServiceAdmin converts a service UserSubscription to DTO for admin users.
|
||||
// It includes assignment metadata and notes.
|
||||
func UserSubscriptionFromServiceAdmin(sub *service.UserSubscription) *AdminUserSubscription {
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminUserSubscription{
|
||||
UserSubscription: userSubscriptionFromServiceBase(sub),
|
||||
AssignedBy: sub.AssignedBy,
|
||||
AssignedAt: sub.AssignedAt,
|
||||
Notes: sub.Notes,
|
||||
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
||||
}
|
||||
}
|
||||
|
||||
func userSubscriptionFromServiceBase(sub *service.UserSubscription) UserSubscription {
|
||||
return UserSubscription{
|
||||
ID: sub.ID,
|
||||
UserID: sub.UserID,
|
||||
GroupID: sub.GroupID,
|
||||
@@ -344,14 +491,10 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
|
||||
DailyUsageUSD: sub.DailyUsageUSD,
|
||||
WeeklyUsageUSD: sub.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: sub.MonthlyUsageUSD,
|
||||
AssignedBy: sub.AssignedBy,
|
||||
AssignedAt: sub.AssignedAt,
|
||||
Notes: sub.Notes,
|
||||
CreatedAt: sub.CreatedAt,
|
||||
UpdatedAt: sub.UpdatedAt,
|
||||
User: UserFromServiceShallow(sub.User),
|
||||
Group: GroupFromServiceShallow(sub.Group),
|
||||
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -359,9 +502,9 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
subs := make([]UserSubscription, 0, len(r.Subscriptions))
|
||||
subs := make([]AdminUserSubscription, 0, len(r.Subscriptions))
|
||||
for i := range r.Subscriptions {
|
||||
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
|
||||
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
|
||||
}
|
||||
return &BulkAssignResult{
|
||||
SuccessCount: r.SuccessCount,
|
||||
|
||||
@@ -2,8 +2,12 @@ package dto
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
@@ -22,13 +26,14 @@ type SystemSettings struct {
|
||||
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
@@ -52,19 +57,23 @@ type SystemSettings struct {
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||
|
||||
@@ -6,7 +6,6 @@ type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
@@ -19,6 +18,14 @@ type User struct {
|
||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUser 是管理员接口使用的 user DTO(包含敏感/内部字段)。
|
||||
// 注意:普通用户接口不得返回 notes 等管理员备注信息。
|
||||
type AdminUser struct {
|
||||
User
|
||||
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -60,6 +67,16 @@ type Group struct {
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AdminGroup 是管理员接口使用的 group DTO(包含敏感/内部字段)。
|
||||
// 注意:普通用户接口不得返回 model_routing/account_count/account_groups 等内部信息。
|
||||
type AdminGroup struct {
|
||||
Group
|
||||
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
@@ -76,6 +93,7 @@ type Account struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
@@ -97,6 +115,25 @@ type Account struct {
|
||||
SessionWindowEnd *time.Time `json:"session_window_end"`
|
||||
SessionWindowStatus string `json:"session_window_status"`
|
||||
|
||||
// 5h窗口费用控制(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
WindowCostLimit *float64 `json:"window_cost_limit,omitempty"`
|
||||
WindowCostStickyReserve *float64 `json:"window_cost_sticky_reserve,omitempty"`
|
||||
|
||||
// 会话数量控制(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
MaxSessions *int `json:"max_sessions,omitempty"`
|
||||
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
|
||||
|
||||
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
|
||||
|
||||
// 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 启用后将在15分钟内固定 metadata.user_id 中的 session ID
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
@@ -129,7 +166,23 @@ type Proxy struct {
|
||||
|
||||
type ProxyWithAccountCount struct {
|
||||
Proxy
|
||||
AccountCount int64 `json:"account_count"`
|
||||
AccountCount int64 `json:"account_count"`
|
||||
LatencyMs *int64 `json:"latency_ms,omitempty"`
|
||||
LatencyStatus string `json:"latency_status,omitempty"`
|
||||
LatencyMessage string `json:"latency_message,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
}
|
||||
|
||||
type ProxyAccountSummary struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
}
|
||||
|
||||
type RedeemCode struct {
|
||||
@@ -140,7 +193,6 @@ type RedeemCode struct {
|
||||
Status string `json:"status"`
|
||||
UsedBy *int64 `json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
Notes string `json:"notes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
@@ -150,6 +202,15 @@ type RedeemCode struct {
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
// AdminRedeemCode 是管理员接口使用的 redeem code DTO(包含 notes 等字段)。
|
||||
// 注意:普通用户接口不得返回 notes 等内部信息。
|
||||
type AdminRedeemCode struct {
|
||||
RedeemCode
|
||||
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// UsageLog 是普通用户接口使用的 usage log DTO(不包含管理员字段)。
|
||||
type UsageLog struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -189,18 +250,55 @@ type UsageLog struct {
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
|
||||
// IP 地址(仅管理员可见)
|
||||
IPAddress *string `json:"ip_address,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
APIKey *APIKey `json:"api_key,omitempty"`
|
||||
Account *AccountSummary `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
|
||||
Group *Group `json:"group,omitempty"`
|
||||
Subscription *UserSubscription `json:"subscription,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUsageLog 是管理员接口使用的 usage log DTO(包含管理员字段)。
|
||||
type AdminUsageLog struct {
|
||||
UsageLog
|
||||
|
||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
|
||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||
|
||||
// IPAddress 用户请求 IP(仅管理员可见)
|
||||
IPAddress *string `json:"ip_address,omitempty"`
|
||||
|
||||
// Account 最小账号信息(避免泄露敏感字段)
|
||||
Account *AccountSummary `json:"account,omitempty"`
|
||||
}
|
||||
|
||||
type UsageCleanupFilters struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
UserID *int64 `json:"user_id,omitempty"`
|
||||
APIKeyID *int64 `json:"api_key_id,omitempty"`
|
||||
AccountID *int64 `json:"account_id,omitempty"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
BillingType *int8 `json:"billing_type,omitempty"`
|
||||
}
|
||||
|
||||
type UsageCleanupTask struct {
|
||||
ID int64 `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Filters UsageCleanupFilters `json:"filters"`
|
||||
CreatedBy int64 `json:"created_by"`
|
||||
DeletedRows int64 `json:"deleted_rows"`
|
||||
ErrorMessage *string `json:"error_message,omitempty"`
|
||||
CanceledBy *int64 `json:"canceled_by,omitempty"`
|
||||
CanceledAt *time.Time `json:"canceled_at,omitempty"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
FinishedAt *time.Time `json:"finished_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AccountSummary is a minimal account info for usage log display.
|
||||
// It intentionally excludes sensitive fields like Credentials, Proxy, etc.
|
||||
type AccountSummary struct {
|
||||
@@ -232,23 +330,30 @@ type UserSubscription struct {
|
||||
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
|
||||
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUserSubscription 是管理员接口使用的订阅 DTO(包含分配信息/备注等字段)。
|
||||
// 注意:普通用户接口不得返回 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段。
|
||||
type AdminUserSubscription struct {
|
||||
UserSubscription
|
||||
|
||||
AssignedBy *int64 `json:"assigned_by"`
|
||||
AssignedAt time.Time `json:"assigned_at"`
|
||||
Notes string `json:"notes"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
||||
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
||||
}
|
||||
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []AdminUserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
|
||||
@@ -31,6 +31,8 @@ type GatewayHandler struct {
|
||||
userService *service.UserService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
maxAccountSwitchesGemini int
|
||||
}
|
||||
|
||||
// NewGatewayHandler creates a new GatewayHandler
|
||||
@@ -44,8 +46,16 @@ func NewGatewayHandler(
|
||||
cfg *config.Config,
|
||||
) *GatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 10
|
||||
maxAccountSwitchesGemini := 3
|
||||
if cfg != nil {
|
||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||
}
|
||||
if cfg.Gateway.MaxAccountSwitchesGemini > 0 {
|
||||
maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
|
||||
}
|
||||
}
|
||||
return &GatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
@@ -54,6 +64,8 @@ func NewGatewayHandler(
|
||||
userService: userService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,13 +191,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
if platform == service.PlatformGemini {
|
||||
const maxAccountSwitches = 3
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -197,17 +209,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body)
|
||||
if interceptType != InterceptTypeNone {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockInterceptStream(c, reqModel, interceptType)
|
||||
} else {
|
||||
sendMockInterceptResponse(c, reqModel, interceptType)
|
||||
}
|
||||
return
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
@@ -313,14 +328,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
const maxAccountSwitches = 10
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -332,17 +347,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body)
|
||||
if interceptType != InterceptTypeNone {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockInterceptStream(c, reqModel, interceptType)
|
||||
} else {
|
||||
sendMockInterceptResponse(c, reqModel, interceptType)
|
||||
}
|
||||
return
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
@@ -753,17 +771,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// isWarmupRequest 检测是否为预热请求(标题生成、Warmup等)
|
||||
func isWarmupRequest(body []byte) bool {
|
||||
// 快速检查:如果body不包含关键字,直接返回false
|
||||
// InterceptType 表示请求拦截类型
|
||||
type InterceptType int
|
||||
|
||||
const (
|
||||
InterceptTypeNone InterceptType = iota
|
||||
InterceptTypeWarmup // 预热请求(返回 "New Conversation")
|
||||
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
|
||||
)
|
||||
|
||||
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
|
||||
func detectInterceptType(body []byte) InterceptType {
|
||||
// 快速检查:如果不包含任何关键字,直接返回
|
||||
bodyStr := string(body)
|
||||
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
|
||||
return false
|
||||
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
|
||||
hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
|
||||
|
||||
if !hasSuggestionMode && !hasWarmupKeyword {
|
||||
return InterceptTypeNone
|
||||
}
|
||||
|
||||
// 解析完整请求
|
||||
// 解析请求(只解析一次)
|
||||
var req struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
@@ -774,43 +805,71 @@ func isWarmupRequest(body []byte) bool {
|
||||
} `json:"system"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return false
|
||||
return InterceptTypeNone
|
||||
}
|
||||
|
||||
// 检查 messages 中的标题提示模式
|
||||
for _, msg := range req.Messages {
|
||||
for _, content := range msg.Content {
|
||||
if content.Type == "text" {
|
||||
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
||||
content.Text == "Warmup" {
|
||||
return true
|
||||
// 检查 SUGGESTION MODE(最后一条 user 消息)
|
||||
if hasSuggestionMode && len(req.Messages) > 0 {
|
||||
lastMsg := req.Messages[len(req.Messages)-1]
|
||||
if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
|
||||
lastMsg.Content[0].Type == "text" &&
|
||||
strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
|
||||
return InterceptTypeSuggestionMode
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 Warmup 请求
|
||||
if hasWarmupKeyword {
|
||||
// 检查 messages 中的标题提示模式
|
||||
for _, msg := range req.Messages {
|
||||
for _, content := range msg.Content {
|
||||
if content.Type == "text" {
|
||||
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
||||
content.Text == "Warmup" {
|
||||
return InterceptTypeWarmup
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 检查 system 中的标题提取模式
|
||||
for _, sys := range req.System {
|
||||
if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
||||
return InterceptTypeWarmup
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 system 中的标题提取模式
|
||||
for _, system := range req.System {
|
||||
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return InterceptTypeNone
|
||||
}
|
||||
|
||||
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
// sendMockInterceptStream 发送流式 mock 响应(用于请求拦截)
|
||||
func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// 根据拦截类型决定响应内容
|
||||
var msgID string
|
||||
var outputTokens int
|
||||
var textDeltas []string
|
||||
|
||||
switch interceptType {
|
||||
case InterceptTypeSuggestionMode:
|
||||
msgID = "msg_mock_suggestion"
|
||||
outputTokens = 1
|
||||
textDeltas = []string{""} // 空内容
|
||||
default: // InterceptTypeWarmup
|
||||
msgID = "msg_mock_warmup"
|
||||
outputTokens = 2
|
||||
textDeltas = []string{"New", " Conversation"}
|
||||
}
|
||||
|
||||
// Build message_start event with proper JSON marshaling
|
||||
messageStart := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": "msg_mock_warmup",
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
@@ -825,16 +884,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
}
|
||||
messageStartJSON, _ := json.Marshal(messageStart)
|
||||
|
||||
// Build events
|
||||
events := []string{
|
||||
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
|
||||
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
|
||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
|
||||
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
|
||||
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
|
||||
}
|
||||
|
||||
// Add text deltas
|
||||
for _, text := range textDeltas {
|
||||
delta := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]string{
|
||||
"type": "text_delta",
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
deltaJSON, _ := json.Marshal(delta)
|
||||
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
|
||||
}
|
||||
|
||||
// Add final events
|
||||
messageDelta := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
}
|
||||
messageDeltaJSON, _ := json.Marshal(messageDelta)
|
||||
|
||||
events = append(events,
|
||||
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
|
||||
`event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
|
||||
`event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
|
||||
)
|
||||
|
||||
for _, event := range events {
|
||||
_, _ = c.Writer.WriteString(event + "\n\n")
|
||||
c.Writer.Flush()
|
||||
@@ -842,18 +931,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
}
|
||||
}
|
||||
|
||||
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupResponse(c *gin.Context, model string) {
|
||||
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
|
||||
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
|
||||
var msgID, text string
|
||||
var outputTokens int
|
||||
|
||||
switch interceptType {
|
||||
case InterceptTypeSuggestionMode:
|
||||
msgID = "msg_mock_suggestion"
|
||||
text = ""
|
||||
outputTokens = 1
|
||||
default: // InterceptTypeWarmup
|
||||
msgID = "msg_mock_warmup"
|
||||
text = "New Conversation"
|
||||
outputTokens = 2
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": "msg_mock_warmup",
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||
"content": []gin.H{{"type": "text", "text": text}},
|
||||
"stop_reason": "end_turn",
|
||||
"usage": gin.H{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 2,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
122
backend/internal/handler/gemini_cli_session_test.go
Normal file
122
backend/internal/handler/gemini_cli_session_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractGeminiCLISessionHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
privilegedUserID string
|
||||
wantEmpty bool
|
||||
wantHash string
|
||||
}{
|
||||
{
|
||||
name: "with privileged-user-id and tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: false,
|
||||
wantHash: func() string {
|
||||
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
|
||||
hash := sha256.Sum256([]byte(combined))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "without privileged-user-id but with tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
|
||||
privilegedUserID: "",
|
||||
wantEmpty: false,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "without tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 创建测试上下文
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/test", nil)
|
||||
if tt.privilegedUserID != "" {
|
||||
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
|
||||
}
|
||||
|
||||
// 调用函数
|
||||
result := extractGeminiCLISessionHash(c, []byte(tt.body))
|
||||
|
||||
// 验证结果
|
||||
if tt.wantEmpty {
|
||||
require.Empty(t, result, "expected empty session hash")
|
||||
} else {
|
||||
require.NotEmpty(t, result, "expected non-empty session hash")
|
||||
require.Equal(t, tt.wantHash, result, "session hash mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiCLITmpDirRegex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantMatch bool
|
||||
wantHash string
|
||||
}{
|
||||
{
|
||||
name: "valid tmp dir path",
|
||||
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
wantMatch: true,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "valid tmp dir path in text",
|
||||
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
|
||||
wantMatch: true,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "invalid hash length",
|
||||
input: "/Users/ianshaw/.gemini/tmp/abc123",
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
name: "no tmp dir",
|
||||
input: "Hello world",
|
||||
wantMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
|
||||
if tt.wantMatch {
|
||||
require.NotNil(t, match, "expected regex to match")
|
||||
require.Len(t, match, 2, "expected 2 capture groups")
|
||||
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
|
||||
} else {
|
||||
require.Nil(t, match, "expected regex not to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,15 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -19,6 +23,17 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
|
||||
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
|
||||
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
|
||||
|
||||
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
|
||||
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
|
||||
return true
|
||||
}
|
||||
return geminiCLITmpDirRegex.Match(body)
|
||||
}
|
||||
|
||||
// GeminiV1BetaListModels proxies:
|
||||
// GET /v1beta/models
|
||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
@@ -214,19 +229,33 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
// 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希)
|
||||
sessionHash := extractGeminiCLISessionHash(c, body)
|
||||
if sessionHash == "" {
|
||||
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
}
|
||||
sessionKey := sessionHash
|
||||
if sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
|
||||
// 查询粘性会话绑定的账号 ID(用于检测账号切换)
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
}
|
||||
isCLI := isGeminiCLIRequest(c, body)
|
||||
cleanedForUnknownBinding := false
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
|
||||
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
|
||||
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
|
||||
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
|
||||
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||||
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
cleanedForUnknownBinding = true
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionBoundAccountID == 0 {
|
||||
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
|
||||
sessionBoundAccountID = account.ID
|
||||
}
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
||||
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
||||
//
|
||||
// 会话标识生成策略:
|
||||
// 1. 从请求体中提取 tmp 目录哈希(64位十六进制)
|
||||
// 2. 从 header 中提取 privileged-user-id(UUID)
|
||||
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
|
||||
//
|
||||
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
|
||||
//
|
||||
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
|
||||
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
|
||||
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
|
||||
// 1. 从请求体中提取 tmp 目录哈希
|
||||
match := geminiCLITmpDirRegex.FindSubmatch(body)
|
||||
if len(match) < 2 {
|
||||
return "" // 没有找到 tmp 目录,不使用粘性会话
|
||||
}
|
||||
tmpDirHash := string(match[1])
|
||||
|
||||
// 2. 提取 privileged-user-id
|
||||
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
|
||||
|
||||
// 3. 组合生成最终的 session hash
|
||||
if privilegedUserID != "" {
|
||||
// 组合两个标识符:privileged-user-id + tmp 目录哈希
|
||||
combined := privilegedUserID + ":" + tmpDirHash
|
||||
hash := sha256.Sum256([]byte(combined))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// 如果没有 privileged-user-id,直接使用 tmp 目录哈希
|
||||
return tmpDirHash
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ type Handlers struct {
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time information
|
||||
|
||||
@@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
@@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler(
|
||||
cfg *config.Config,
|
||||
) *OpenAIGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 3
|
||||
if cfg != nil {
|
||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||
}
|
||||
}
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,6 +120,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
|
||||
// 或带 id 且与 call_id 匹配的 item_reference。
|
||||
if service.HasFunctionCallOutput(reqBody) {
|
||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
||||
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
||||
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
||||
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
@@ -166,10 +192,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate session hash (from header for OpenAI)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
||||
// Generate session hash (header first; fallback to prompt_cache_key)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
|
||||
|
||||
const maxAccountSwitches = 3
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
@@ -544,6 +544,11 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
body := w.buf.Bytes()
|
||||
parsed := parseOpsErrorResponse(body)
|
||||
|
||||
// Skip logging if the error should be filtered based on settings
|
||||
if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
|
||||
return
|
||||
}
|
||||
|
||||
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
|
||||
|
||||
clientRequestID, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
||||
@@ -832,28 +837,30 @@ func normalizeOpsErrorType(errType string, code string) string {
|
||||
|
||||
func classifyOpsPhase(errType, message, code string) string {
|
||||
msg := strings.ToLower(message)
|
||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||
// Map billing/concurrency/response => request; scheduling => routing.
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
return "billing"
|
||||
return "request"
|
||||
}
|
||||
|
||||
switch errType {
|
||||
case "authentication_error":
|
||||
return "auth"
|
||||
case "billing_error", "subscription_error":
|
||||
return "billing"
|
||||
return "request"
|
||||
case "rate_limit_error":
|
||||
if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") || strings.Contains(msg, "queue") {
|
||||
return "concurrency"
|
||||
return "request"
|
||||
}
|
||||
return "upstream"
|
||||
case "invalid_request_error":
|
||||
return "response"
|
||||
return "request"
|
||||
case "upstream_error", "overloaded_error":
|
||||
return "upstream"
|
||||
case "api_error":
|
||||
if strings.Contains(msg, "no available accounts") {
|
||||
return "scheduling"
|
||||
return "routing"
|
||||
}
|
||||
return "internal"
|
||||
default:
|
||||
@@ -914,34 +921,38 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
|
||||
}
|
||||
|
||||
func classifyOpsErrorOwner(phase string, message string) string {
|
||||
// Standardized owners: client|provider|platform
|
||||
switch phase {
|
||||
case "upstream", "network":
|
||||
return "provider"
|
||||
case "billing", "concurrency", "auth", "response":
|
||||
case "request", "auth":
|
||||
return "client"
|
||||
case "routing", "internal":
|
||||
return "platform"
|
||||
default:
|
||||
if strings.Contains(strings.ToLower(message), "upstream") {
|
||||
return "provider"
|
||||
}
|
||||
return "sub2api"
|
||||
return "platform"
|
||||
}
|
||||
}
|
||||
|
||||
func classifyOpsErrorSource(phase string, message string) string {
|
||||
// Standardized sources: client_request|upstream_http|gateway
|
||||
switch phase {
|
||||
case "upstream":
|
||||
return "upstream_http"
|
||||
case "network":
|
||||
return "upstream_network"
|
||||
case "billing":
|
||||
return "billing"
|
||||
case "concurrency":
|
||||
return "concurrency"
|
||||
return "gateway"
|
||||
case "request", "auth":
|
||||
return "client_request"
|
||||
case "routing", "internal":
|
||||
return "gateway"
|
||||
default:
|
||||
if strings.Contains(strings.ToLower(message), "upstream") {
|
||||
return "upstream_http"
|
||||
}
|
||||
return "internal"
|
||||
return "gateway"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -963,3 +974,42 @@ func truncateString(s string, max int) string {
|
||||
func strconvItoa(v int) string {
|
||||
return strconv.Itoa(v)
|
||||
}
|
||||
|
||||
// shouldSkipOpsErrorLog determines if an error should be skipped from logging based on settings.
|
||||
// Returns true for errors that should be filtered according to OpsAdvancedSettings.
|
||||
func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message, body, requestPath string) bool {
|
||||
if ops == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get advanced settings to check filter configuration
|
||||
settings, err := ops.GetOpsAdvancedSettings(ctx)
|
||||
if err != nil || settings == nil {
|
||||
// If we can't get settings, don't skip (fail open)
|
||||
return false
|
||||
}
|
||||
|
||||
msgLower := strings.ToLower(message)
|
||||
bodyLower := strings.ToLower(body)
|
||||
|
||||
// Check if count_tokens errors should be ignored
|
||||
if settings.IgnoreCountTokensErrors && strings.Contains(requestPath, "/count_tokens") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if context canceled errors should be ignored (client disconnects)
|
||||
if settings.IgnoreContextCanceled {
|
||||
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if "no available accounts" errors should be ignored
|
||||
if settings.IgnoreNoAvailableAccounts {
|
||||
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -32,18 +32,21 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
181
backend/internal/handler/totp_handler.go
Normal file
181
backend/internal/handler/totp_handler.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// TotpHandler handles TOTP-related requests
|
||||
type TotpHandler struct {
|
||||
totpService *service.TotpService
|
||||
}
|
||||
|
||||
// NewTotpHandler creates a new TotpHandler
|
||||
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
|
||||
return &TotpHandler{
|
||||
totpService: totpService,
|
||||
}
|
||||
}
|
||||
|
||||
// TotpStatusResponse represents the TOTP status response
|
||||
type TotpStatusResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
|
||||
FeatureEnabled bool `json:"feature_enabled"`
|
||||
}
|
||||
|
||||
// GetStatus returns the TOTP status for the current user
|
||||
// GET /api/v1/user/totp/status
|
||||
func (h *TotpHandler) GetStatus(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := TotpStatusResponse{
|
||||
Enabled: status.Enabled,
|
||||
FeatureEnabled: status.FeatureEnabled,
|
||||
}
|
||||
|
||||
if status.EnabledAt != nil {
|
||||
ts := status.EnabledAt.Unix()
|
||||
resp.EnabledAt = &ts
|
||||
}
|
||||
|
||||
response.Success(c, resp)
|
||||
}
|
||||
|
||||
// TotpSetupRequest represents the request to initiate TOTP setup
|
||||
type TotpSetupRequest struct {
|
||||
EmailCode string `json:"email_code"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// TotpSetupResponse represents the TOTP setup response
|
||||
type TotpSetupResponse struct {
|
||||
Secret string `json:"secret"`
|
||||
QRCodeURL string `json:"qr_code_url"`
|
||||
SetupToken string `json:"setup_token"`
|
||||
Countdown int `json:"countdown"`
|
||||
}
|
||||
|
||||
// InitiateSetup starts the TOTP setup process
|
||||
// POST /api/v1/user/totp/setup
|
||||
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpSetupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// Allow empty body (optional params)
|
||||
req = TotpSetupRequest{}
|
||||
}
|
||||
|
||||
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, TotpSetupResponse{
|
||||
Secret: result.Secret,
|
||||
QRCodeURL: result.QRCodeURL,
|
||||
SetupToken: result.SetupToken,
|
||||
Countdown: result.Countdown,
|
||||
})
|
||||
}
|
||||
|
||||
// TotpEnableRequest represents the request to enable TOTP
|
||||
type TotpEnableRequest struct {
|
||||
TotpCode string `json:"totp_code" binding:"required,len=6"`
|
||||
SetupToken string `json:"setup_token" binding:"required"`
|
||||
}
|
||||
|
||||
// Enable completes the TOTP setup
|
||||
// POST /api/v1/user/totp/enable
|
||||
func (h *TotpHandler) Enable(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpEnableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// TotpDisableRequest represents the request to disable TOTP
|
||||
type TotpDisableRequest struct {
|
||||
EmailCode string `json:"email_code"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// Disable disables TOTP for the current user
|
||||
// POST /api/v1/user/totp/disable
|
||||
func (h *TotpHandler) Disable(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpDisableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// GetVerificationMethod returns the verification method for TOTP operations
|
||||
// GET /api/v1/user/totp/verification-method
|
||||
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
|
||||
method := h.totpService.GetVerificationMethod(c.Request.Context())
|
||||
response.Success(c, method)
|
||||
}
|
||||
|
||||
// SendVerifyCode sends an email verification code for TOTP operations
|
||||
// POST /api/v1/user/totp/send-code
|
||||
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
@@ -47,9 +47,6 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
userData.Notes = ""
|
||||
|
||||
response.Success(c, dto.UserFromService(userData))
|
||||
}
|
||||
|
||||
@@ -105,8 +102,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
updatedUser.Notes = ""
|
||||
|
||||
response.Success(c, dto.UserFromService(updatedUser))
|
||||
}
|
||||
|
||||
@@ -70,6 +70,7 @@ func ProvideHandlers(
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
totpHandler *TotpHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
@@ -82,6 +83,7 @@ func ProvideHandlers(
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
Totp: totpHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewSubscriptionHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
NewTotpHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
|
||||
@@ -2,7 +2,10 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -25,15 +28,34 @@ type RateLimitOptions struct {
|
||||
var rateLimitScript = redis.NewScript(`
|
||||
local current = redis.call('INCR', KEYS[1])
|
||||
local ttl = redis.call('PTTL', KEYS[1])
|
||||
if current == 1 or ttl == -1 then
|
||||
local repaired = 0
|
||||
if current == 1 then
|
||||
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
||||
elseif ttl == -1 then
|
||||
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
||||
repaired = 1
|
||||
end
|
||||
return current
|
||||
return {current, repaired}
|
||||
`)
|
||||
|
||||
// rateLimitRun 允许测试覆写脚本执行逻辑
|
||||
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
||||
return rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Int64()
|
||||
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
||||
values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice()
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
if len(values) < 2 {
|
||||
return 0, false, fmt.Errorf("rate limit script returned %d values", len(values))
|
||||
}
|
||||
count, err := parseInt64(values[0])
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
repaired, err := parseInt64(values[1])
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return count, repaired == 1, nil
|
||||
}
|
||||
|
||||
// RateLimiter Redis 速率限制器
|
||||
@@ -74,8 +96,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
|
||||
windowMillis := windowTTLMillis(window)
|
||||
|
||||
// 使用 Lua 脚本原子操作增加计数并设置过期
|
||||
count, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
|
||||
count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
|
||||
if err != nil {
|
||||
log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err)
|
||||
if failureMode == RateLimitFailClose {
|
||||
abortRateLimit(c)
|
||||
return
|
||||
@@ -84,6 +107,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
if repaired {
|
||||
log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis)
|
||||
}
|
||||
|
||||
// 超过限制
|
||||
if count > int64(limit) {
|
||||
@@ -109,3 +135,27 @@ func abortRateLimit(c *gin.Context) {
|
||||
"message": "Too many requests, please try again later",
|
||||
})
|
||||
}
|
||||
|
||||
func failureModeLabel(mode RateLimitFailureMode) string {
|
||||
if mode == RateLimitFailClose {
|
||||
return "fail-close"
|
||||
}
|
||||
return "fail-open"
|
||||
}
|
||||
|
||||
func parseInt64(value any) (int64, error) {
|
||||
switch v := value.(type) {
|
||||
case int64:
|
||||
return v, nil
|
||||
case int:
|
||||
return int64(v), nil
|
||||
case string:
|
||||
parsed, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return parsed, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unexpected value type %T", value)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -88,6 +91,7 @@ func performRequest(router *gin.Engine) *httptest.ResponseRecorder {
|
||||
|
||||
func startRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||
t.Helper()
|
||||
ensureDockerAvailable(t)
|
||||
|
||||
redisContainer, err := tcredis.Run(ctx, redisImageTag)
|
||||
require.NoError(t, err)
|
||||
@@ -112,3 +116,43 @@ func startRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||
|
||||
return rdb
|
||||
}
|
||||
|
||||
func ensureDockerAvailable(t *testing.T) {
|
||||
t.Helper()
|
||||
if dockerAvailable() {
|
||||
return
|
||||
}
|
||||
t.Skip("Docker 未启用,跳过依赖 testcontainers 的集成测试")
|
||||
}
|
||||
|
||||
func dockerAvailable() bool {
|
||||
if os.Getenv("DOCKER_HOST") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
socketCandidates := []string{
|
||||
"/var/run/docker.sock",
|
||||
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
|
||||
filepath.Join(userHomeDir(), ".docker", "run", "docker.sock"),
|
||||
filepath.Join(userHomeDir(), ".docker", "desktop", "docker.sock"),
|
||||
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
|
||||
}
|
||||
|
||||
for _, socket := range socketCandidates {
|
||||
if socket == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(socket); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func userHomeDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return home
|
||||
}
|
||||
|
||||
@@ -66,13 +66,13 @@ func TestRateLimiterSuccessAndLimit(t *testing.T) {
|
||||
originalRun := rateLimitRun
|
||||
counts := []int64{1, 2}
|
||||
callIndex := 0
|
||||
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
||||
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
||||
if callIndex >= len(counts) {
|
||||
return counts[len(counts)-1], nil
|
||||
return counts[len(counts)-1], false, nil
|
||||
}
|
||||
value := counts[callIndex]
|
||||
callIndex++
|
||||
return value, nil
|
||||
return value, false, nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
rateLimitRun = originalRun
|
||||
|
||||
@@ -16,15 +16,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// resolveHost 从 URL 解析 host
|
||||
func resolveHost(urlStr string) string {
|
||||
parsed, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return parsed.Host
|
||||
}
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
@@ -39,23 +30,11 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 基础 Headers
|
||||
// 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
// Accept Header 根据请求类型设置
|
||||
if isStream {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
// 显式设置 Host Header
|
||||
if host := resolveHost(apiURL); host != "" {
|
||||
req.Host = host
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -195,12 +174,15 @@ func isConnectionError(err error) bool {
|
||||
}
|
||||
|
||||
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
|
||||
// 仅连接错误和 HTTP 429 触发 URL 降级
|
||||
// 与 Antigravity-Manager 保持一致:连接错误、429、408、404、5xx 触发 URL 降级
|
||||
func shouldFallbackToNextURL(err error, statusCode int) bool {
|
||||
if isConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
return statusCode == http.StatusTooManyRequests
|
||||
return statusCode == http.StatusTooManyRequests ||
|
||||
statusCode == http.StatusRequestTimeout ||
|
||||
statusCode == http.StatusNotFound ||
|
||||
statusCode >= 500
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
@@ -321,11 +303,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取可用的 URL 列表
|
||||
availableURLs := DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
// 固定顺序:prod -> daily
|
||||
availableURLs := BaseURLs
|
||||
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
@@ -343,7 +322,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
@@ -358,7 +336,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
@@ -376,6 +353,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
// 标记成功的 URL,下次优先使用
|
||||
DefaultURLAvailability.MarkSuccess(baseURL)
|
||||
return &loadResp, rawResp, nil
|
||||
}
|
||||
|
||||
@@ -412,11 +391,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取可用的 URL 列表
|
||||
availableURLs := DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
// 固定顺序:prod -> daily
|
||||
availableURLs := BaseURLs
|
||||
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
@@ -434,7 +410,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
@@ -449,7 +424,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
@@ -467,6 +441,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
// 标记成功的 URL,下次优先使用
|
||||
DefaultURLAvailability.MarkSuccess(baseURL)
|
||||
return &modelsResp, rawResp, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -143,9 +143,10 @@ type GeminiResponse struct {
|
||||
|
||||
// GeminiCandidate Gemini 候选响应
|
||||
type GeminiCandidate struct {
|
||||
Content *GeminiContent `json:"content,omitempty"`
|
||||
FinishReason string `json:"finishReason,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Content *GeminiContent `json:"content,omitempty"`
|
||||
FinishReason string `json:"finishReason,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiUsageMetadata Gemini 用量元数据
|
||||
@@ -156,6 +157,23 @@ type GeminiUsageMetadata struct {
|
||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
|
||||
type GeminiGroundingMetadata struct {
|
||||
WebSearchQueries []string `json:"webSearchQueries,omitempty"`
|
||||
GroundingChunks []GeminiGroundingChunk `json:"groundingChunks,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiGroundingChunk Gemini grounding chunk
|
||||
type GeminiGroundingChunk struct {
|
||||
Web *GeminiGroundingWeb `json:"web,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiGroundingWeb Gemini grounding web 信息
|
||||
type GeminiGroundingWeb struct {
|
||||
Title string `json:"title,omitempty"`
|
||||
URI string `json:"uri,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
|
||||
var DefaultSafetySettings = []GeminiSafetySetting{
|
||||
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},
|
||||
|
||||
@@ -32,8 +32,8 @@ const (
|
||||
"https://www.googleapis.com/auth/cclog " +
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||
|
||||
// User-Agent(模拟官方客户端)
|
||||
UserAgent = "antigravity/1.104.0 darwin/arm64"
|
||||
// User-Agent(与 Antigravity-Manager 保持一致)
|
||||
UserAgent = "antigravity/1.11.9 windows/amd64"
|
||||
|
||||
// Session 过期时间
|
||||
SessionTTL = 30 * time.Minute
|
||||
@@ -42,22 +42,21 @@ const (
|
||||
URLAvailabilityTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// BaseURLs 定义 Antigravity API 端点,按优先级排序
|
||||
// fallback 顺序: sandbox → daily → prod
|
||||
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
|
||||
var BaseURLs = []string{
|
||||
"https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox
|
||||
"https://daily-cloudcode-pa.googleapis.com", // daily
|
||||
"https://cloudcode-pa.googleapis.com", // prod
|
||||
"https://cloudcode-pa.googleapis.com", // prod (优先)
|
||||
"https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用)
|
||||
}
|
||||
|
||||
// BaseURL 默认 URL(保持向后兼容)
|
||||
var BaseURL = BaseURLs[0]
|
||||
|
||||
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复)
|
||||
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级)
|
||||
type URLAvailability struct {
|
||||
mu sync.RWMutex
|
||||
unavailable map[string]time.Time // URL -> 恢复时间
|
||||
ttl time.Duration
|
||||
lastSuccess string // 最近成功请求的 URL,优先使用
|
||||
}
|
||||
|
||||
// DefaultURLAvailability 全局 URL 可用性管理器
|
||||
@@ -78,6 +77,15 @@ func (u *URLAvailability) MarkUnavailable(url string) {
|
||||
u.unavailable[url] = time.Now().Add(u.ttl)
|
||||
}
|
||||
|
||||
// MarkSuccess 标记 URL 请求成功,将其设为优先使用
|
||||
func (u *URLAvailability) MarkSuccess(url string) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
u.lastSuccess = url
|
||||
// 成功后清除该 URL 的不可用标记
|
||||
delete(u.unavailable, url)
|
||||
}
|
||||
|
||||
// IsAvailable 检查 URL 是否可用
|
||||
func (u *URLAvailability) IsAvailable(url string) bool {
|
||||
u.mu.RLock()
|
||||
@@ -89,14 +97,29 @@ func (u *URLAvailability) IsAvailable(url string) bool {
|
||||
return time.Now().After(expiry)
|
||||
}
|
||||
|
||||
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序)
|
||||
// GetAvailableURLs 返回可用的 URL 列表
|
||||
// 最近成功的 URL 优先,其他按默认顺序
|
||||
func (u *URLAvailability) GetAvailableURLs() []string {
|
||||
u.mu.RLock()
|
||||
defer u.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
result := make([]string, 0, len(BaseURLs))
|
||||
|
||||
// 如果有最近成功的 URL 且可用,放在最前面
|
||||
if u.lastSuccess != "" {
|
||||
expiry, exists := u.unavailable[u.lastSuccess]
|
||||
if !exists || now.After(expiry) {
|
||||
result = append(result, u.lastSuccess)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加其他可用的 URL(按默认顺序)
|
||||
for _, url := range BaseURLs {
|
||||
// 跳过已添加的 lastSuccess
|
||||
if url == u.lastSuccess {
|
||||
continue
|
||||
}
|
||||
expiry, exists := u.unavailable[url]
|
||||
if !exists || now.After(expiry) {
|
||||
result = append(result, url)
|
||||
@@ -240,24 +263,3 @@ func BuildAuthorizationURL(state, codeChallenge string) string {
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// GenerateMockProjectID 生成随机 project_id(当 API 不返回时使用)
|
||||
// 格式:{形容词}-{名词}-{5位随机字符}
|
||||
func GenerateMockProjectID() string {
|
||||
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
||||
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
||||
|
||||
randBytes, _ := GenerateRandomBytes(7)
|
||||
|
||||
adj := adjectives[int(randBytes[0])%len(adjectives)]
|
||||
noun := nouns[int(randBytes[1])%len(nouns)]
|
||||
|
||||
// 生成 5 位随机字符(a-z0-9)
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
suffix := make([]byte, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
suffix[i] = charset[int(randBytes[i+2])%len(charset)]
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
|
||||
}
|
||||
|
||||
@@ -7,13 +7,11 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@@ -54,6 +52,9 @@ func DefaultTransformOptions() TransformOptions {
|
||||
}
|
||||
}
|
||||
|
||||
// webSearchFallbackModel web_search 请求使用的降级模型
|
||||
const webSearchFallbackModel = "gemini-2.5-flash"
|
||||
|
||||
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
|
||||
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
|
||||
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
|
||||
@@ -64,12 +65,23 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
// 用于存储 tool_use id -> name 映射
|
||||
toolIDToName := make(map[string]string)
|
||||
|
||||
// 检测是否有 web_search 工具
|
||||
hasWebSearchTool := hasWebSearchTool(claudeReq.Tools)
|
||||
requestType := "agent"
|
||||
targetModel := mappedModel
|
||||
if hasWebSearchTool {
|
||||
requestType = "web_search"
|
||||
if targetModel != webSearchFallbackModel {
|
||||
targetModel = webSearchFallbackModel
|
||||
}
|
||||
}
|
||||
|
||||
// 检测是否启用 thinking
|
||||
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
||||
|
||||
// 只有 Gemini 模型支持 dummy thought workaround
|
||||
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
||||
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
|
||||
allowDummyThought := strings.HasPrefix(targetModel, "gemini-")
|
||||
|
||||
// 1. 构建 contents
|
||||
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
|
||||
@@ -78,7 +90,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
}
|
||||
|
||||
// 2. 构建 systemInstruction
|
||||
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts)
|
||||
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts, claudeReq.Tools)
|
||||
|
||||
// 3. 构建 generationConfig
|
||||
reqForConfig := claudeReq
|
||||
@@ -89,6 +101,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
reqCopy.Thinking = nil
|
||||
reqForConfig = &reqCopy
|
||||
}
|
||||
if targetModel != "" && targetModel != reqForConfig.Model {
|
||||
reqCopy := *reqForConfig
|
||||
reqCopy.Model = targetModel
|
||||
reqForConfig = &reqCopy
|
||||
}
|
||||
generationConfig := buildGenerationConfig(reqForConfig)
|
||||
|
||||
// 4. 构建 tools
|
||||
@@ -127,8 +144,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
Project: projectID,
|
||||
RequestID: "agent-" + uuid.New().String(),
|
||||
UserAgent: "antigravity", // 固定值,与官方客户端一致
|
||||
RequestType: "agent",
|
||||
Model: mappedModel,
|
||||
RequestType: requestType,
|
||||
Model: targetModel,
|
||||
Request: innerRequest,
|
||||
}
|
||||
|
||||
@@ -154,8 +171,40 @@ func GetDefaultIdentityPatch() string {
|
||||
return antigravityIdentity
|
||||
}
|
||||
|
||||
// buildSystemInstruction 构建 systemInstruction
|
||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
|
||||
// mcpXMLProtocol MCP XML 工具调用协议(与 Antigravity-Manager 保持一致)
|
||||
const mcpXMLProtocol = `
|
||||
==== MCP XML 工具调用协议 (Workaround) ====
|
||||
当你需要调用名称以 ` + "`mcp__`" + ` 开头的 MCP 工具时:
|
||||
1) 优先尝试 XML 格式调用:输出 ` + "`<mcp__tool_name>{\"arg\":\"value\"}</mcp__tool_name>`" + `。
|
||||
2) 必须直接输出 XML 块,无需 markdown 包装,内容为 JSON 格式的入参。
|
||||
3) 这种方式具有更高的连通性和容错性,适用于大型结果返回场景。
|
||||
===========================================`
|
||||
|
||||
// hasMCPTools 检测是否有 mcp__ 前缀的工具
|
||||
func hasMCPTools(tools []ClaudeTool) bool {
|
||||
for _, tool := range tools {
|
||||
if strings.HasPrefix(tool.Name, "mcp__") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// filterOpenCodePrompt 过滤 OpenCode 默认提示词,只保留用户自定义指令
|
||||
func filterOpenCodePrompt(text string) string {
|
||||
if !strings.Contains(text, "You are an interactive CLI tool") {
|
||||
return text
|
||||
}
|
||||
// 提取 "Instructions from:" 及之后的部分
|
||||
if idx := strings.Index(text, "Instructions from:"); idx >= 0 {
|
||||
return text[idx:]
|
||||
}
|
||||
// 如果没有自定义指令,返回空
|
||||
return ""
|
||||
}
|
||||
|
||||
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
||||
var parts []GeminiPart
|
||||
|
||||
// 先解析用户的 system prompt,检测是否已包含 Antigravity identity
|
||||
@@ -167,10 +216,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
var sysStr string
|
||||
if err := json.Unmarshal(system, &sysStr); err == nil {
|
||||
if strings.TrimSpace(sysStr) != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr})
|
||||
if strings.Contains(sysStr, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
// 过滤 OpenCode 默认提示词
|
||||
filtered := filterOpenCodePrompt(sysStr)
|
||||
if filtered != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 尝试解析为数组
|
||||
@@ -178,10 +231,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
if err := json.Unmarshal(system, &sysBlocks); err == nil {
|
||||
for _, block := range sysBlocks {
|
||||
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text})
|
||||
if strings.Contains(block.Text, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
// 过滤 OpenCode 默认提示词
|
||||
filtered := filterOpenCodePrompt(block.Text)
|
||||
if filtered != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -200,6 +257,16 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
||||
// 添加用户的 system prompt
|
||||
parts = append(parts, userSystemParts...)
|
||||
|
||||
// 检测是否有 MCP 工具,如有则注入 XML 调用协议
|
||||
if hasMCPTools(tools) {
|
||||
parts = append(parts, GeminiPart{Text: mcpXMLProtocol})
|
||||
}
|
||||
|
||||
// 如果用户没有提供 Antigravity 身份,添加结束标记
|
||||
if !userHasAntigravityIdentity {
|
||||
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -300,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
Text: block.Thinking,
|
||||
Thought: true,
|
||||
}
|
||||
// 保留原有 signature(Claude 模型需要有效的 signature)
|
||||
if block.Signature != "" {
|
||||
// signature 处理:
|
||||
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
||||
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
||||
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
|
||||
part.ThoughtSignature = block.Signature
|
||||
} else if !allowDummyThought {
|
||||
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
|
||||
@@ -340,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
},
|
||||
}
|
||||
// tool_use 的 signature 处理:
|
||||
// - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验)
|
||||
// - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路)
|
||||
if allowDummyThought {
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
|
||||
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
||||
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
||||
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
|
||||
part.ThoughtSignature = block.Signature
|
||||
} else if allowDummyThought {
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
}
|
||||
parts = append(parts, part)
|
||||
|
||||
@@ -429,6 +498,11 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
||||
StopSequences: DefaultStopSequences,
|
||||
}
|
||||
|
||||
// 如果请求中指定了 MaxTokens,使用请求值
|
||||
if req.MaxTokens > 0 {
|
||||
config.MaxOutputTokens = req.MaxTokens
|
||||
}
|
||||
|
||||
// Thinking 配置
|
||||
if req.Thinking != nil && req.Thinking.Type == "enabled" {
|
||||
config.ThinkingConfig = &GeminiThinkingConfig{
|
||||
@@ -458,37 +532,43 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
||||
return config
|
||||
}
|
||||
|
||||
func hasWebSearchTool(tools []ClaudeTool) bool {
|
||||
for _, tool := range tools {
|
||||
if isWebSearchTool(tool) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isWebSearchTool(tool ClaudeTool) bool {
|
||||
if strings.HasPrefix(tool.Type, "web_search") || tool.Type == "google_search" {
|
||||
return true
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(tool.Name)
|
||||
switch name {
|
||||
case "web_search", "google_search", "web_search_20250305":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// buildTools 构建 tools
|
||||
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否有 web_search 工具
|
||||
hasWebSearch := false
|
||||
for _, tool := range tools {
|
||||
if tool.Name == "web_search" {
|
||||
hasWebSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasWebSearch {
|
||||
// Web Search 工具映射
|
||||
return []GeminiToolDeclaration{{
|
||||
GoogleSearch: &GeminiGoogleSearch{
|
||||
EnhancedContent: &GeminiEnhancedContent{
|
||||
ImageSearch: &GeminiImageSearch{
|
||||
MaxResultCount: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
}
|
||||
hasWebSearch := hasWebSearchTool(tools)
|
||||
|
||||
// 普通工具
|
||||
var funcDecls []GeminiFunctionDecl
|
||||
for _, tool := range tools {
|
||||
if isWebSearchTool(tool) {
|
||||
continue
|
||||
}
|
||||
// 跳过无效工具名称
|
||||
if strings.TrimSpace(tool.Name) == "" {
|
||||
log.Printf("Warning: skipping tool with empty name")
|
||||
@@ -514,11 +594,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
}
|
||||
|
||||
// 清理 JSON Schema
|
||||
params := cleanJSONSchema(inputSchema)
|
||||
// 1. 深度清理 [undefined] 值
|
||||
DeepCleanUndefined(inputSchema)
|
||||
// 2. 转换为符合 Gemini v1internal 的 schema
|
||||
params := CleanJSONSchema(inputSchema)
|
||||
// 为 nil schema 提供默认值
|
||||
if params == nil {
|
||||
params = map[string]any{
|
||||
"type": "OBJECT",
|
||||
"type": "object", // lowercase type
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
@@ -531,243 +614,23 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
}
|
||||
|
||||
if len(funcDecls) == 0 {
|
||||
return nil
|
||||
if !hasWebSearch {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Web Search 工具映射
|
||||
return []GeminiToolDeclaration{{
|
||||
GoogleSearch: &GeminiGoogleSearch{
|
||||
EnhancedContent: &GeminiEnhancedContent{
|
||||
ImageSearch: &GeminiImageSearch{
|
||||
MaxResultCount: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
return []GeminiToolDeclaration{{
|
||||
FunctionDeclarations: funcDecls,
|
||||
}}
|
||||
}
|
||||
|
||||
// cleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
|
||||
// 参考 proxycast 的实现,确保 schema 符合 JSON Schema draft 2020-12
|
||||
func cleanJSONSchema(schema map[string]any) map[string]any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
cleaned := cleanSchemaValue(schema, "$")
|
||||
result, ok := cleaned.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 确保有 type 字段(默认 OBJECT)
|
||||
if _, hasType := result["type"]; !hasType {
|
||||
result["type"] = "OBJECT"
|
||||
}
|
||||
|
||||
// 确保有 properties 字段(默认空对象)
|
||||
if _, hasProps := result["properties"]; !hasProps {
|
||||
result["properties"] = make(map[string]any)
|
||||
}
|
||||
|
||||
// 验证 required 中的字段都存在于 properties 中
|
||||
if required, ok := result["required"].([]any); ok {
|
||||
if props, ok := result["properties"].(map[string]any); ok {
|
||||
validRequired := make([]any, 0, len(required))
|
||||
for _, r := range required {
|
||||
if reqName, ok := r.(string); ok {
|
||||
if _, exists := props[reqName]; exists {
|
||||
validRequired = append(validRequired, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(validRequired) > 0 {
|
||||
result["required"] = validRequired
|
||||
} else {
|
||||
delete(result, "required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
var schemaValidationKeys = map[string]bool{
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"pattern": true,
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"exclusiveMinimum": true,
|
||||
"exclusiveMaximum": true,
|
||||
"multipleOf": true,
|
||||
"uniqueItems": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
"minProperties": true,
|
||||
"maxProperties": true,
|
||||
"patternProperties": true,
|
||||
"propertyNames": true,
|
||||
"dependencies": true,
|
||||
"dependentSchemas": true,
|
||||
"dependentRequired": true,
|
||||
}
|
||||
|
||||
var warnedSchemaKeys sync.Map
|
||||
|
||||
func schemaCleaningWarningsEnabled() bool {
|
||||
// 可通过环境变量强制开关,方便排查:SUB2API_SCHEMA_CLEAN_WARN=true/false
|
||||
if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
|
||||
switch strings.ToLower(v) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
}
|
||||
}
|
||||
// 默认:非 release 模式下输出(debug/test)
|
||||
return gin.Mode() != gin.ReleaseMode
|
||||
}
|
||||
|
||||
func warnSchemaKeyRemovedOnce(key, path string) {
|
||||
if !schemaCleaningWarningsEnabled() {
|
||||
return
|
||||
}
|
||||
if !schemaValidationKeys[key] {
|
||||
return
|
||||
}
|
||||
if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
|
||||
return
|
||||
}
|
||||
log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
|
||||
}
|
||||
|
||||
// excludedSchemaKeys 不支持的 schema 字段
|
||||
// 基于 Claude API (Vertex AI) 的实际支持情况
|
||||
// 支持: type, description, enum, properties, required, additionalProperties, items
|
||||
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
|
||||
var excludedSchemaKeys = map[string]bool{
|
||||
// 元 schema 字段
|
||||
"$schema": true,
|
||||
"$id": true,
|
||||
"$ref": true,
|
||||
|
||||
// 字符串验证(Gemini 不支持)
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"pattern": true,
|
||||
|
||||
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"exclusiveMinimum": true,
|
||||
"exclusiveMaximum": true,
|
||||
"multipleOf": true,
|
||||
|
||||
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
|
||||
"uniqueItems": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
|
||||
// 组合 schema(Gemini 不支持)
|
||||
"oneOf": true,
|
||||
"anyOf": true,
|
||||
"allOf": true,
|
||||
"not": true,
|
||||
"if": true,
|
||||
"then": true,
|
||||
"else": true,
|
||||
"$defs": true,
|
||||
"definitions": true,
|
||||
|
||||
// 对象验证(仅保留 properties/required/additionalProperties)
|
||||
"minProperties": true,
|
||||
"maxProperties": true,
|
||||
"patternProperties": true,
|
||||
"propertyNames": true,
|
||||
"dependencies": true,
|
||||
"dependentSchemas": true,
|
||||
"dependentRequired": true,
|
||||
|
||||
// 其他不支持的字段
|
||||
"default": true,
|
||||
"const": true,
|
||||
"examples": true,
|
||||
"deprecated": true,
|
||||
"readOnly": true,
|
||||
"writeOnly": true,
|
||||
"contentMediaType": true,
|
||||
"contentEncoding": true,
|
||||
|
||||
// Claude 特有字段
|
||||
"strict": true,
|
||||
}
|
||||
|
||||
// cleanSchemaValue 递归清理 schema 值
|
||||
func cleanSchemaValue(value any, path string) any {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
result := make(map[string]any)
|
||||
for k, val := range v {
|
||||
// 跳过不支持的字段
|
||||
if excludedSchemaKeys[k] {
|
||||
warnSchemaKeyRemovedOnce(k, path)
|
||||
continue
|
||||
}
|
||||
|
||||
// 特殊处理 type 字段
|
||||
if k == "type" {
|
||||
result[k] = cleanTypeValue(val)
|
||||
continue
|
||||
}
|
||||
|
||||
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
|
||||
if k == "format" {
|
||||
if formatStr, ok := val.(string); ok {
|
||||
// Gemini 只支持 date-time, date, time
|
||||
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
|
||||
result[k] = val
|
||||
}
|
||||
// 其他 format 值直接跳过
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
|
||||
if k == "additionalProperties" {
|
||||
if boolVal, ok := val.(bool); ok {
|
||||
result[k] = boolVal
|
||||
} else {
|
||||
// 如果是 schema 对象,转换为 false(更安全的默认值)
|
||||
result[k] = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 递归清理所有值
|
||||
result[k] = cleanSchemaValue(val, path+"."+k)
|
||||
}
|
||||
return result
|
||||
|
||||
case []any:
|
||||
// 递归处理数组中的每个元素
|
||||
cleaned := make([]any, 0, len(v))
|
||||
for i, item := range v {
|
||||
cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
|
||||
}
|
||||
return cleaned
|
||||
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// cleanTypeValue 处理 type 字段,转换为大写
|
||||
func cleanTypeValue(value any) any {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return strings.ToUpper(v)
|
||||
case []any:
|
||||
// 联合类型 ["string", "null"] -> 取第一个非 null 类型
|
||||
for _, t := range v {
|
||||
if ts, ok := t.(string); ok && ts != "null" {
|
||||
return strings.ToUpper(ts)
|
||||
}
|
||||
}
|
||||
// 如果只有 null,返回 STRING
|
||||
return "STRING"
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
||||
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
|
||||
]`
|
||||
|
||||
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
|
||||
t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
|
||||
if err != nil {
|
||||
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||
}
|
||||
if parts[0].ThoughtSignature != "sig_tool_abc" {
|
||||
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
|
||||
contentNoSig := `[
|
||||
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
|
||||
]`
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
|
||||
if err != nil {
|
||||
t.Fatalf("buildParts() error = %v", err)
|
||||
}
|
||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||
}
|
||||
if parts[0].ThoughtSignature != dummyThoughtSignature {
|
||||
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package antigravity
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
|
||||
@@ -18,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
|
||||
v1Resp.Response = directResp
|
||||
v1Resp.ResponseID = directResp.ResponseID
|
||||
v1Resp.ModelVersion = directResp.ModelVersion
|
||||
} else if len(v1Resp.Response.Candidates) == 0 {
|
||||
// 第一次解析成功但 candidates 为空,说明是直接的 GeminiResponse 格式
|
||||
var directResp GeminiResponse
|
||||
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
|
||||
return nil, nil, fmt.Errorf("parse gemini response as direct: %w", err2)
|
||||
}
|
||||
v1Resp.Response = directResp
|
||||
v1Resp.ResponseID = directResp.ResponseID
|
||||
v1Resp.ModelVersion = directResp.ModelVersion
|
||||
}
|
||||
|
||||
// 使用处理器转换
|
||||
@@ -63,6 +74,12 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID,
|
||||
p.processPart(&part)
|
||||
}
|
||||
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
if grounding := geminiResp.Candidates[0].GroundingMetadata; grounding != nil {
|
||||
p.processGrounding(grounding)
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新剩余内容
|
||||
p.flushThinking()
|
||||
p.flushText()
|
||||
@@ -166,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
p.textBuilder += part.Text
|
||||
|
||||
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
|
||||
// 非空 text 带签名 - 特殊处理:先输出 text,再输出空 thinking 块
|
||||
if signature != "" {
|
||||
p.flushText()
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "text",
|
||||
Text: part.Text,
|
||||
})
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: signature,
|
||||
})
|
||||
} else {
|
||||
// 普通 text (无签名) - 累积到 builder
|
||||
p.textBuilder += part.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -190,6 +211,18 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *NonStreamingProcessor) processGrounding(grounding *GeminiGroundingMetadata) {
|
||||
groundingText := buildGroundingText(grounding)
|
||||
if groundingText == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p.flushThinking()
|
||||
p.flushText()
|
||||
p.textBuilder += groundingText
|
||||
p.flushText()
|
||||
}
|
||||
|
||||
// flushText 刷新 text builder
|
||||
func (p *NonStreamingProcessor) flushText() {
|
||||
if p.textBuilder == "" {
|
||||
@@ -223,6 +256,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
|
||||
var finishReason string
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
finishReason = geminiResp.Candidates[0].FinishReason
|
||||
if finishReason == "MALFORMED_FUNCTION_CALL" {
|
||||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in response for model %s", originalModel)
|
||||
if geminiResp.Candidates[0].Content != nil {
|
||||
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
|
||||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stopReason := "end_turn"
|
||||
@@ -262,6 +303,44 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
|
||||
}
|
||||
}
|
||||
|
||||
func buildGroundingText(grounding *GeminiGroundingMetadata) string {
|
||||
if grounding == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
|
||||
if len(grounding.WebSearchQueries) > 0 {
|
||||
_, _ = builder.WriteString("\n\n---\nWeb search queries: ")
|
||||
_, _ = builder.WriteString(strings.Join(grounding.WebSearchQueries, ", "))
|
||||
}
|
||||
|
||||
if len(grounding.GroundingChunks) > 0 {
|
||||
var links []string
|
||||
for i, chunk := range grounding.GroundingChunks {
|
||||
if chunk.Web == nil {
|
||||
continue
|
||||
}
|
||||
title := strings.TrimSpace(chunk.Web.Title)
|
||||
if title == "" {
|
||||
title = "Source"
|
||||
}
|
||||
uri := strings.TrimSpace(chunk.Web.URI)
|
||||
if uri == "" {
|
||||
uri = "#"
|
||||
}
|
||||
links = append(links, fmt.Sprintf("[%d] [%s](%s)", i+1, title, uri))
|
||||
}
|
||||
|
||||
if len(links) > 0 {
|
||||
_, _ = builder.WriteString("\n\nSources:\n")
|
||||
_, _ = builder.WriteString(strings.Join(links, "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// generateRandomID 生成随机 ID
|
||||
func generateRandomID() string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
519
backend/internal/pkg/antigravity/schema_cleaner.go
Normal file
519
backend/internal/pkg/antigravity/schema_cleaner.go
Normal file
@@ -0,0 +1,519 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
|
||||
// 参考 Antigravity-Manager/src-tauri/src/proxy/common/json_schema.rs 实现
|
||||
// 确保 schema 符合 JSON Schema draft 2020-12 且适配 Gemini v1internal
|
||||
func CleanJSONSchema(schema map[string]any) map[string]any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
// 0. 预处理:展开 $ref (Schema Flattening)
|
||||
// (Go map 是引用的,直接修改 schema)
|
||||
flattenRefs(schema, extractDefs(schema))
|
||||
|
||||
// 递归清理
|
||||
cleaned := cleanJSONSchemaRecursive(schema)
|
||||
result, ok := cleaned.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// extractDefs 提取并移除定义的 helper
|
||||
func extractDefs(schema map[string]any) map[string]any {
|
||||
defs := make(map[string]any)
|
||||
if d, ok := schema["$defs"].(map[string]any); ok {
|
||||
for k, v := range d {
|
||||
defs[k] = v
|
||||
}
|
||||
delete(schema, "$defs")
|
||||
}
|
||||
if d, ok := schema["definitions"].(map[string]any); ok {
|
||||
for k, v := range d {
|
||||
defs[k] = v
|
||||
}
|
||||
delete(schema, "definitions")
|
||||
}
|
||||
return defs
|
||||
}
|
||||
|
||||
// flattenRefs 递归展开 $ref
|
||||
func flattenRefs(schema map[string]any, defs map[string]any) {
|
||||
if len(defs) == 0 {
|
||||
return // 无需展开
|
||||
}
|
||||
|
||||
// 检查并替换 $ref
|
||||
if ref, ok := schema["$ref"].(string); ok {
|
||||
delete(schema, "$ref")
|
||||
// 解析引用名 (例如 #/$defs/MyType -> MyType)
|
||||
parts := strings.Split(ref, "/")
|
||||
refName := parts[len(parts)-1]
|
||||
|
||||
if defSchema, exists := defs[refName]; exists {
|
||||
if defMap, ok := defSchema.(map[string]any); ok {
|
||||
// 合并定义内容 (不覆盖现有 key)
|
||||
for k, v := range defMap {
|
||||
if _, has := schema[k]; !has {
|
||||
schema[k] = deepCopy(v) // 需深拷贝避免共享引用
|
||||
}
|
||||
}
|
||||
// 递归处理刚刚合并进来的内容
|
||||
flattenRefs(schema, defs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 遍历子节点
|
||||
for _, v := range schema {
|
||||
if subMap, ok := v.(map[string]any); ok {
|
||||
flattenRefs(subMap, defs)
|
||||
} else if subArr, ok := v.([]any); ok {
|
||||
for _, item := range subArr {
|
||||
if itemMap, ok := item.(map[string]any); ok {
|
||||
flattenRefs(itemMap, defs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// deepCopy 深拷贝 (简单实现,仅针对 JSON 类型)
|
||||
func deepCopy(src any) any {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
switch v := src.(type) {
|
||||
case map[string]any:
|
||||
dst := make(map[string]any)
|
||||
for k, val := range v {
|
||||
dst[k] = deepCopy(val)
|
||||
}
|
||||
return dst
|
||||
case []any:
|
||||
dst := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
dst[i] = deepCopy(val)
|
||||
}
|
||||
return dst
|
||||
default:
|
||||
return src
|
||||
}
|
||||
}
|
||||
|
||||
// cleanJSONSchemaRecursive 递归核心清理逻辑
|
||||
// 返回处理后的值 (通常是 input map,但可能修改内部结构)
|
||||
func cleanJSONSchemaRecursive(value any) any {
|
||||
schemaMap, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return value
|
||||
}
|
||||
|
||||
// 0. [NEW] 合并 allOf
|
||||
mergeAllOf(schemaMap)
|
||||
|
||||
// 1. [CRITICAL] 深度递归处理子项
|
||||
if props, ok := schemaMap["properties"].(map[string]any); ok {
|
||||
for _, v := range props {
|
||||
cleanJSONSchemaRecursive(v)
|
||||
}
|
||||
// Go 中不需要像 Rust 那样显式处理 nullable_keys remove required,
|
||||
// 因为我们在子项处理中会正确设置 type 和 description
|
||||
} else if items, ok := schemaMap["items"]; ok {
|
||||
// [FIX] Gemini 期望 "items" 是单个 Schema 对象(列表验证),而不是数组(元组验证)。
|
||||
if itemsArr, ok := items.([]any); ok {
|
||||
// 策略:将元组 [A, B] 视为 A、B 中的最佳匹配项。
|
||||
best := extractBestSchemaFromUnion(itemsArr)
|
||||
if best == nil {
|
||||
// 回退到通用字符串
|
||||
best = map[string]any{"type": "string"}
|
||||
}
|
||||
// 用处理后的对象替换原有数组
|
||||
cleanedBest := cleanJSONSchemaRecursive(best)
|
||||
schemaMap["items"] = cleanedBest
|
||||
} else {
|
||||
cleanJSONSchemaRecursive(items)
|
||||
}
|
||||
} else {
|
||||
// 遍历所有值递归
|
||||
for _, v := range schemaMap {
|
||||
if _, isMap := v.(map[string]any); isMap {
|
||||
cleanJSONSchemaRecursive(v)
|
||||
} else if arr, isArr := v.([]any); isArr {
|
||||
for _, item := range arr {
|
||||
cleanJSONSchemaRecursive(item)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. [FIX] 处理 anyOf/oneOf 联合类型: 合并属性而非直接删除
|
||||
var unionArray []any
|
||||
typeStr, _ := schemaMap["type"].(string)
|
||||
if typeStr == "" || typeStr == "object" {
|
||||
if anyOf, ok := schemaMap["anyOf"].([]any); ok {
|
||||
unionArray = anyOf
|
||||
} else if oneOf, ok := schemaMap["oneOf"].([]any); ok {
|
||||
unionArray = oneOf
|
||||
}
|
||||
}
|
||||
|
||||
if len(unionArray) > 0 {
|
||||
if bestBranch := extractBestSchemaFromUnion(unionArray); bestBranch != nil {
|
||||
if bestMap, ok := bestBranch.(map[string]any); ok {
|
||||
// 合并分支内容
|
||||
for k, v := range bestMap {
|
||||
if k == "properties" {
|
||||
targetProps, _ := schemaMap["properties"].(map[string]any)
|
||||
if targetProps == nil {
|
||||
targetProps = make(map[string]any)
|
||||
schemaMap["properties"] = targetProps
|
||||
}
|
||||
if sourceProps, ok := v.(map[string]any); ok {
|
||||
for pk, pv := range sourceProps {
|
||||
if _, exists := targetProps[pk]; !exists {
|
||||
targetProps[pk] = deepCopy(pv)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if k == "required" {
|
||||
targetReq, _ := schemaMap["required"].([]any)
|
||||
if sourceReq, ok := v.([]any); ok {
|
||||
for _, rv := range sourceReq {
|
||||
// 简单的去重添加
|
||||
exists := false
|
||||
for _, tr := range targetReq {
|
||||
if tr == rv {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
targetReq = append(targetReq, rv)
|
||||
}
|
||||
}
|
||||
schemaMap["required"] = targetReq
|
||||
}
|
||||
} else if _, exists := schemaMap[k]; !exists {
|
||||
schemaMap[k] = deepCopy(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. [SAFETY] 检查当前对象是否为 JSON Schema 节点
|
||||
looksLikeSchema := hasKey(schemaMap, "type") ||
|
||||
hasKey(schemaMap, "properties") ||
|
||||
hasKey(schemaMap, "items") ||
|
||||
hasKey(schemaMap, "enum") ||
|
||||
hasKey(schemaMap, "anyOf") ||
|
||||
hasKey(schemaMap, "oneOf") ||
|
||||
hasKey(schemaMap, "allOf")
|
||||
|
||||
if looksLikeSchema {
|
||||
// 4. [ROBUST] 约束迁移
|
||||
migrateConstraints(schemaMap)
|
||||
|
||||
// 5. [CRITICAL] 白名单过滤
|
||||
allowedFields := map[string]bool{
|
||||
"type": true,
|
||||
"description": true,
|
||||
"properties": true,
|
||||
"required": true,
|
||||
"items": true,
|
||||
"enum": true,
|
||||
"title": true,
|
||||
}
|
||||
for k := range schemaMap {
|
||||
if !allowedFields[k] {
|
||||
delete(schemaMap, k)
|
||||
}
|
||||
}
|
||||
|
||||
// 6. [SAFETY] 处理空 Object
|
||||
if t, _ := schemaMap["type"].(string); t == "object" {
|
||||
hasProps := false
|
||||
if props, ok := schemaMap["properties"].(map[string]any); ok && len(props) > 0 {
|
||||
hasProps = true
|
||||
}
|
||||
if !hasProps {
|
||||
schemaMap["properties"] = map[string]any{
|
||||
"reason": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Reason for calling this tool",
|
||||
},
|
||||
}
|
||||
schemaMap["required"] = []any{"reason"}
|
||||
}
|
||||
}
|
||||
|
||||
// 7. [SAFETY] Required 字段对齐
|
||||
if props, ok := schemaMap["properties"].(map[string]any); ok {
|
||||
if req, ok := schemaMap["required"].([]any); ok {
|
||||
var validReq []any
|
||||
for _, r := range req {
|
||||
if rStr, ok := r.(string); ok {
|
||||
if _, exists := props[rStr]; exists {
|
||||
validReq = append(validReq, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(validReq) > 0 {
|
||||
schemaMap["required"] = validReq
|
||||
} else {
|
||||
delete(schemaMap, "required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 8. 处理 type 字段 (Lowercase + Nullable 提取)
|
||||
isEffectivelyNullable := false
|
||||
if typeVal, exists := schemaMap["type"]; exists {
|
||||
var selectedType string
|
||||
switch v := typeVal.(type) {
|
||||
case string:
|
||||
lower := strings.ToLower(v)
|
||||
if lower == "null" {
|
||||
isEffectivelyNullable = true
|
||||
selectedType = "string" // fallback
|
||||
} else {
|
||||
selectedType = lower
|
||||
}
|
||||
case []any:
|
||||
// ["string", "null"]
|
||||
for _, t := range v {
|
||||
if ts, ok := t.(string); ok {
|
||||
lower := strings.ToLower(ts)
|
||||
if lower == "null" {
|
||||
isEffectivelyNullable = true
|
||||
} else if selectedType == "" {
|
||||
selectedType = lower
|
||||
}
|
||||
}
|
||||
}
|
||||
if selectedType == "" {
|
||||
selectedType = "string"
|
||||
}
|
||||
}
|
||||
schemaMap["type"] = selectedType
|
||||
} else {
|
||||
// 默认 object 如果有 properties (虽然上面白名单过滤可能删了 type 如果它不在... 但 type 必在 allowlist)
|
||||
// 如果没有 type,但有 properties,补一个
|
||||
if hasKey(schemaMap, "properties") {
|
||||
schemaMap["type"] = "object"
|
||||
} else {
|
||||
// 默认为 string ? or object? Gemini 通常需要明确 type
|
||||
schemaMap["type"] = "object"
|
||||
}
|
||||
}
|
||||
|
||||
if isEffectivelyNullable {
|
||||
desc, _ := schemaMap["description"].(string)
|
||||
if !strings.Contains(desc, "nullable") {
|
||||
if desc != "" {
|
||||
desc += " "
|
||||
}
|
||||
desc += "(nullable)"
|
||||
schemaMap["description"] = desc
|
||||
}
|
||||
}
|
||||
|
||||
// 9. Enum 值强制转字符串
|
||||
if enumVals, ok := schemaMap["enum"].([]any); ok {
|
||||
hasNonString := false
|
||||
for i, val := range enumVals {
|
||||
if _, isStr := val.(string); !isStr {
|
||||
hasNonString = true
|
||||
if val == nil {
|
||||
enumVals[i] = "null"
|
||||
} else {
|
||||
enumVals[i] = fmt.Sprintf("%v", val)
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we mandated string values, we must ensure type is string
|
||||
if hasNonString {
|
||||
schemaMap["type"] = "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return schemaMap
|
||||
}
|
||||
|
||||
func hasKey(m map[string]any, k string) bool {
|
||||
_, ok := m[k]
|
||||
return ok
|
||||
}
|
||||
|
||||
func migrateConstraints(m map[string]any) {
|
||||
constraints := []struct {
|
||||
key string
|
||||
label string
|
||||
}{
|
||||
{"minLength", "minLen"},
|
||||
{"maxLength", "maxLen"},
|
||||
{"pattern", "pattern"},
|
||||
{"minimum", "min"},
|
||||
{"maximum", "max"},
|
||||
{"multipleOf", "multipleOf"},
|
||||
{"exclusiveMinimum", "exclMin"},
|
||||
{"exclusiveMaximum", "exclMax"},
|
||||
{"minItems", "minItems"},
|
||||
{"maxItems", "maxItems"},
|
||||
{"propertyNames", "propertyNames"},
|
||||
{"format", "format"},
|
||||
}
|
||||
|
||||
var hints []string
|
||||
for _, c := range constraints {
|
||||
if val, ok := m[c.key]; ok && val != nil {
|
||||
hints = append(hints, fmt.Sprintf("%s: %v", c.label, val))
|
||||
}
|
||||
}
|
||||
|
||||
if len(hints) > 0 {
|
||||
suffix := fmt.Sprintf(" [Constraint: %s]", strings.Join(hints, ", "))
|
||||
desc, _ := m["description"].(string)
|
||||
if !strings.Contains(desc, suffix) {
|
||||
m["description"] = desc + suffix
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mergeAllOf 合并 allOf
|
||||
func mergeAllOf(m map[string]any) {
|
||||
allOf, ok := m["allOf"].([]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(m, "allOf")
|
||||
|
||||
mergedProps := make(map[string]any)
|
||||
mergedReq := make(map[string]bool)
|
||||
otherFields := make(map[string]any)
|
||||
|
||||
for _, sub := range allOf {
|
||||
if subMap, ok := sub.(map[string]any); ok {
|
||||
// Props
|
||||
if props, ok := subMap["properties"].(map[string]any); ok {
|
||||
for k, v := range props {
|
||||
mergedProps[k] = v
|
||||
}
|
||||
}
|
||||
// Required
|
||||
if reqs, ok := subMap["required"].([]any); ok {
|
||||
for _, r := range reqs {
|
||||
if s, ok := r.(string); ok {
|
||||
mergedReq[s] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// Others
|
||||
for k, v := range subMap {
|
||||
if k != "properties" && k != "required" && k != "allOf" {
|
||||
if _, exists := otherFields[k]; !exists {
|
||||
otherFields[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply
|
||||
for k, v := range otherFields {
|
||||
if _, exists := m[k]; !exists {
|
||||
m[k] = v
|
||||
}
|
||||
}
|
||||
if len(mergedProps) > 0 {
|
||||
existProps, _ := m["properties"].(map[string]any)
|
||||
if existProps == nil {
|
||||
existProps = make(map[string]any)
|
||||
m["properties"] = existProps
|
||||
}
|
||||
for k, v := range mergedProps {
|
||||
if _, exists := existProps[k]; !exists {
|
||||
existProps[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(mergedReq) > 0 {
|
||||
existReq, _ := m["required"].([]any)
|
||||
var validReqs []any
|
||||
for _, r := range existReq {
|
||||
if s, ok := r.(string); ok {
|
||||
validReqs = append(validReqs, s)
|
||||
delete(mergedReq, s) // already exists
|
||||
}
|
||||
}
|
||||
// append new
|
||||
for r := range mergedReq {
|
||||
validReqs = append(validReqs, r)
|
||||
}
|
||||
m["required"] = validReqs
|
||||
}
|
||||
}
|
||||
|
||||
// extractBestSchemaFromUnion 从 anyOf/oneOf 中选取最佳分支
|
||||
func extractBestSchemaFromUnion(unionArray []any) any {
|
||||
var bestOption any
|
||||
bestScore := -1
|
||||
|
||||
for _, item := range unionArray {
|
||||
score := scoreSchemaOption(item)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
bestOption = item
|
||||
}
|
||||
}
|
||||
return bestOption
|
||||
}
|
||||
|
||||
func scoreSchemaOption(val any) int {
|
||||
m, ok := val.(map[string]any)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
typeStr, _ := m["type"].(string)
|
||||
|
||||
if hasKey(m, "properties") || typeStr == "object" {
|
||||
return 3
|
||||
}
|
||||
if hasKey(m, "items") || typeStr == "array" {
|
||||
return 2
|
||||
}
|
||||
if typeStr != "" && typeStr != "null" {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// DeepCleanUndefined 深度清理值为 "[undefined]" 的字段
|
||||
func DeepCleanUndefined(value any) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
for k, val := range v {
|
||||
if s, ok := val.(string); ok && s == "[undefined]" {
|
||||
delete(v, k)
|
||||
continue
|
||||
}
|
||||
DeepCleanUndefined(val)
|
||||
}
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
DeepCleanUndefined(val)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -27,6 +28,8 @@ type StreamingProcessor struct {
|
||||
pendingSignature string
|
||||
trailingSignature string
|
||||
originalModel string
|
||||
webSearchQueries []string
|
||||
groundingChunks []GeminiGroundingChunk
|
||||
|
||||
// 累计 usage
|
||||
inputTokens int
|
||||
@@ -93,9 +96,21 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
p.captureGrounding(geminiResp.Candidates[0].GroundingMetadata)
|
||||
}
|
||||
|
||||
// 检查是否结束
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
finishReason := geminiResp.Candidates[0].FinishReason
|
||||
if finishReason == "MALFORMED_FUNCTION_CALL" {
|
||||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in stream for model %s", p.originalModel)
|
||||
if geminiResp.Candidates[0].Content != nil {
|
||||
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
|
||||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
if finishReason != "" {
|
||||
_, _ = result.Write(p.emitFinish(finishReason))
|
||||
}
|
||||
@@ -200,6 +215,20 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
func (p *StreamingProcessor) captureGrounding(grounding *GeminiGroundingMetadata) {
|
||||
if grounding == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(grounding.WebSearchQueries) > 0 && len(p.webSearchQueries) == 0 {
|
||||
p.webSearchQueries = append([]string(nil), grounding.WebSearchQueries...)
|
||||
}
|
||||
|
||||
if len(grounding.GroundingChunks) > 0 && len(p.groundingChunks) == 0 {
|
||||
p.groundingChunks = append([]GeminiGroundingChunk(nil), grounding.GroundingChunks...)
|
||||
}
|
||||
}
|
||||
|
||||
// processThinking 处理 thinking
|
||||
func (p *StreamingProcessor) processThinking(text, signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
@@ -417,6 +446,23 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
if len(p.webSearchQueries) > 0 || len(p.groundingChunks) > 0 {
|
||||
groundingText := buildGroundingText(&GeminiGroundingMetadata{
|
||||
WebSearchQueries: p.webSearchQueries,
|
||||
GroundingChunks: p.groundingChunks,
|
||||
})
|
||||
if groundingText != "" {
|
||||
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
}))
|
||||
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
|
||||
"text": groundingText,
|
||||
}))
|
||||
_, _ = result.Write(p.endBlock())
|
||||
}
|
||||
}
|
||||
|
||||
// 确定 stop_reason
|
||||
stopReason := "end_turn"
|
||||
if p.usedTool {
|
||||
|
||||
@@ -16,14 +16,11 @@ type ModelsListResponse struct {
|
||||
func DefaultModels() []Model {
|
||||
methods := []string{"generateContent", "streamGenerateContent"}
|
||||
return []Model{
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,10 +12,10 @@ type Model struct {
|
||||
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
}
|
||||
|
||||
// DefaultTestModel is the default model to preselect in test flows.
|
||||
|
||||
@@ -13,20 +13,26 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Claude OAuth Constants (from CRS project)
|
||||
// Claude OAuth Constants
|
||||
const (
|
||||
// OAuth Client ID for Claude
|
||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://claude.ai/oauth/authorize"
|
||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
|
||||
TokenURL = "https://platform.claude.com/v1/oauth/token"
|
||||
RedirectURI = "https://platform.claude.com/oauth/code/callback"
|
||||
|
||||
// Scopes
|
||||
ScopeProfile = "user:profile"
|
||||
// Scopes - Browser URL (includes org:create_api_key for user authorization)
|
||||
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
|
||||
// Scopes - Internal API call (org:create_api_key not supported in API)
|
||||
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
|
||||
// Scopes - Setup token (inference only)
|
||||
ScopeInference = "user:inference"
|
||||
|
||||
// Code Verifier character set (RFC 7636 compliant)
|
||||
codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
@@ -53,7 +59,6 @@ func NewSessionStore() *SessionStore {
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
@@ -78,7 +83,6 @@ func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Check if expired
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
@@ -122,13 +126,13 @@ func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GenerateState generates a random state string for OAuth
|
||||
// GenerateState generates a random state string for OAuth (base64url encoded)
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a unique session ID
|
||||
@@ -140,13 +144,30 @@ func GenerateSessionID() (string, error) {
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
|
||||
// GenerateCodeVerifier generates a PKCE code verifier using character set method
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
const targetLen = 32
|
||||
charsetLen := len(codeVerifierCharset)
|
||||
limit := 256 - (256 % charsetLen)
|
||||
|
||||
result := make([]byte, 0, targetLen)
|
||||
randBuf := make([]byte, targetLen*2)
|
||||
|
||||
for len(result) < targetLen {
|
||||
if _, err := rand.Read(randBuf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, b := range randBuf {
|
||||
if int(b) < limit {
|
||||
result = append(result, codeVerifierCharset[int(b)%charsetLen])
|
||||
if len(result) >= targetLen {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
|
||||
return base64URLEncode(result), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
@@ -158,42 +179,31 @@ func GenerateCodeChallenge(verifier string) string {
|
||||
// base64URLEncode encodes bytes to base64url without padding
|
||||
func base64URLEncode(data []byte) string {
|
||||
encoded := base64.URLEncoding.EncodeToString(data)
|
||||
// Remove padding
|
||||
return strings.TrimRight(encoded, "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL builds the OAuth authorization URL
|
||||
// BuildAuthorizationURL builds the OAuth authorization URL with correct parameter order
|
||||
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("scope", scope)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
encodedRedirectURI := url.QueryEscape(RedirectURI)
|
||||
encodedScope := strings.ReplaceAll(url.QueryEscape(scope), "%20", "+")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
ClientID string `json:"client_id"`
|
||||
Code string `json:"code"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
State string `json:"state"`
|
||||
return fmt.Sprintf("%s?code=true&client_id=%s&response_type=code&redirect_uri=%s&scope=%s&code_challenge=%s&code_challenge_method=S256&state=%s",
|
||||
AuthorizeURL,
|
||||
ClientID,
|
||||
encodedRedirectURI,
|
||||
encodedScope,
|
||||
codeChallenge,
|
||||
state,
|
||||
)
|
||||
}
|
||||
|
||||
// TokenResponse represents the token response from OAuth provider
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// Organization and Account info from OAuth response
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Organization *OrgInfo `json:"organization,omitempty"`
|
||||
Account *AccountInfo `json:"account,omitempty"`
|
||||
}
|
||||
@@ -205,33 +215,6 @@ type OrgInfo struct {
|
||||
|
||||
// AccountInfo represents account info from OAuth response
|
||||
type AccountInfo struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
// RefreshTokenRequest represents the refresh token request
|
||||
type RefreshTokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
// BuildTokenRequest creates a token exchange request
|
||||
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
|
||||
return &TokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
ClientID: ClientID,
|
||||
Code: code,
|
||||
RedirectURI: RedirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
State: state,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRefreshTokenRequest creates a refresh token request
|
||||
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
||||
return &RefreshTokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: ClientID,
|
||||
}
|
||||
UUID string `json:"uuid"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
}
|
||||
|
||||
@@ -162,11 +162,11 @@ func ParsePagination(c *gin.Context) (page, pageSize int) {
|
||||
|
||||
// 支持 page_size 和 limit 两种参数名
|
||||
if ps := c.Query("page_size"); ps != "" {
|
||||
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
|
||||
if val, err := parseInt(ps); err == nil && val > 0 && val <= 1000 {
|
||||
pageSize = val
|
||||
}
|
||||
} else if l := c.Query("limit"); l != "" {
|
||||
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
|
||||
if val, err := parseInt(l); err == nil && val > 0 && val <= 1000 {
|
||||
pageSize = val
|
||||
}
|
||||
}
|
||||
|
||||
568
backend/internal/pkg/tlsfingerprint/dialer.go
Normal file
568
backend/internal/pkg/tlsfingerprint/dialer.go
Normal file
@@ -0,0 +1,568 @@
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
// It uses the utls library to create TLS connections that mimic Node.js/Claude Code clients.
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
utls "github.com/refraction-networking/utls"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// Profile contains TLS fingerprint configuration.
|
||||
type Profile struct {
|
||||
Name string // Profile name for identification
|
||||
CipherSuites []uint16
|
||||
Curves []uint16
|
||||
PointFormats []uint8
|
||||
EnableGREASE bool
|
||||
}
|
||||
|
||||
// Dialer creates TLS connections with custom fingerprints.
|
||||
type Dialer struct {
|
||||
profile *Profile
|
||||
baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// HTTPProxyDialer creates TLS connections through HTTP/HTTPS proxies with custom fingerprints.
|
||||
// It handles the CONNECT tunnel establishment before performing TLS handshake.
|
||||
type HTTPProxyDialer struct {
|
||||
profile *Profile
|
||||
proxyURL *url.URL
|
||||
}
|
||||
|
||||
// SOCKS5ProxyDialer creates TLS connections through SOCKS5 proxies with custom fingerprints.
|
||||
// It uses golang.org/x/net/proxy to establish the SOCKS5 tunnel.
|
||||
type SOCKS5ProxyDialer struct {
|
||||
profile *Profile
|
||||
proxyURL *url.URL
|
||||
}
|
||||
|
||||
// Default TLS fingerprint values captured from Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)
|
||||
// Captured using: tshark -i lo -f "tcp port 8443" -Y "tls.handshake.type == 1" -V
|
||||
// JA3 Hash: 1a28e69016765d92e3b381168d68922c
|
||||
//
|
||||
// Note: JA3/JA4 may have slight variations due to:
|
||||
// - Session ticket presence/absence
|
||||
// - Extension negotiation state
|
||||
var (
|
||||
// defaultCipherSuites contains all 59 cipher suites from Claude CLI
|
||||
// Order is critical for JA3 fingerprint matching
|
||||
defaultCipherSuites = []uint16{
|
||||
// TLS 1.3 cipher suites (MUST be first)
|
||||
0x1302, // TLS_AES_256_GCM_SHA384
|
||||
0x1303, // TLS_CHACHA20_POLY1305_SHA256
|
||||
0x1301, // TLS_AES_128_GCM_SHA256
|
||||
|
||||
// ECDHE + AES-GCM
|
||||
0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
|
||||
0xc02b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
|
||||
0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
|
||||
0xc02c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
|
||||
|
||||
// DHE + AES-GCM
|
||||
0x009e, // TLS_DHE_RSA_WITH_AES_128_GCM_SHA256
|
||||
|
||||
// ECDHE/DHE + AES-CBC-SHA256/384
|
||||
0xc027, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256
|
||||
0x0067, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA256
|
||||
0xc028, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384
|
||||
0x006b, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA256
|
||||
|
||||
// DHE-DSS/RSA + AES-GCM
|
||||
0x00a3, // TLS_DHE_DSS_WITH_AES_256_GCM_SHA384
|
||||
0x009f, // TLS_DHE_RSA_WITH_AES_256_GCM_SHA384
|
||||
|
||||
// ChaCha20-Poly1305
|
||||
0xcca9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
|
||||
0xcca8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
|
||||
0xccaa, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256
|
||||
|
||||
// AES-CCM (256-bit)
|
||||
0xc0af, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8
|
||||
0xc0ad, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM
|
||||
0xc0a3, // TLS_DHE_RSA_WITH_AES_256_CCM_8
|
||||
0xc09f, // TLS_DHE_RSA_WITH_AES_256_CCM
|
||||
|
||||
// ARIA (256-bit)
|
||||
0xc05d, // TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384
|
||||
0xc061, // TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384
|
||||
0xc057, // TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384
|
||||
0xc053, // TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384
|
||||
|
||||
// DHE-DSS + AES-GCM (128-bit)
|
||||
0x00a2, // TLS_DHE_DSS_WITH_AES_128_GCM_SHA256
|
||||
|
||||
// AES-CCM (128-bit)
|
||||
0xc0ae, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8
|
||||
0xc0ac, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM
|
||||
0xc0a2, // TLS_DHE_RSA_WITH_AES_128_CCM_8
|
||||
0xc09e, // TLS_DHE_RSA_WITH_AES_128_CCM
|
||||
|
||||
// ARIA (128-bit)
|
||||
0xc05c, // TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256
|
||||
0xc060, // TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256
|
||||
0xc056, // TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256
|
||||
0xc052, // TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256
|
||||
|
||||
// ECDHE/DHE + AES-CBC-SHA384/256 (more)
|
||||
0xc024, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384
|
||||
0x006a, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA256
|
||||
0xc023, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256
|
||||
0x0040, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA256
|
||||
|
||||
// ECDHE/DHE + AES-CBC-SHA (legacy)
|
||||
0xc00a, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
|
||||
0xc014, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA
|
||||
0x0039, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA
|
||||
0x0038, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA
|
||||
0xc009, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA
|
||||
0xc013, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA
|
||||
0x0033, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA
|
||||
0x0032, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA
|
||||
|
||||
// RSA + AES-GCM/CCM/ARIA (non-PFS, 256-bit)
|
||||
0x009d, // TLS_RSA_WITH_AES_256_GCM_SHA384
|
||||
0xc0a1, // TLS_RSA_WITH_AES_256_CCM_8
|
||||
0xc09d, // TLS_RSA_WITH_AES_256_CCM
|
||||
0xc051, // TLS_RSA_WITH_ARIA_256_GCM_SHA384
|
||||
|
||||
// RSA + AES-GCM/CCM/ARIA (non-PFS, 128-bit)
|
||||
0x009c, // TLS_RSA_WITH_AES_128_GCM_SHA256
|
||||
0xc0a0, // TLS_RSA_WITH_AES_128_CCM_8
|
||||
0xc09c, // TLS_RSA_WITH_AES_128_CCM
|
||||
0xc050, // TLS_RSA_WITH_ARIA_128_GCM_SHA256
|
||||
|
||||
// RSA + AES-CBC (non-PFS, legacy)
|
||||
0x003d, // TLS_RSA_WITH_AES_256_CBC_SHA256
|
||||
0x003c, // TLS_RSA_WITH_AES_128_CBC_SHA256
|
||||
0x0035, // TLS_RSA_WITH_AES_256_CBC_SHA
|
||||
0x002f, // TLS_RSA_WITH_AES_128_CBC_SHA
|
||||
|
||||
// Renegotiation indication
|
||||
0x00ff, // TLS_EMPTY_RENEGOTIATION_INFO_SCSV
|
||||
}
|
||||
|
||||
// defaultCurves contains the 10 supported groups from Claude CLI (including FFDHE)
|
||||
defaultCurves = []utls.CurveID{
|
||||
utls.X25519, // 0x001d
|
||||
utls.CurveP256, // 0x0017 (secp256r1)
|
||||
utls.CurveID(0x001e), // x448
|
||||
utls.CurveP521, // 0x0019 (secp521r1)
|
||||
utls.CurveP384, // 0x0018 (secp384r1)
|
||||
utls.CurveID(0x0100), // ffdhe2048
|
||||
utls.CurveID(0x0101), // ffdhe3072
|
||||
utls.CurveID(0x0102), // ffdhe4096
|
||||
utls.CurveID(0x0103), // ffdhe6144
|
||||
utls.CurveID(0x0104), // ffdhe8192
|
||||
}
|
||||
|
||||
// defaultPointFormats contains all 3 point formats from Claude CLI
|
||||
defaultPointFormats = []uint8{
|
||||
0, // uncompressed
|
||||
1, // ansiX962_compressed_prime
|
||||
2, // ansiX962_compressed_char2
|
||||
}
|
||||
|
||||
// defaultSignatureAlgorithms contains the 20 signature algorithms from Claude CLI
|
||||
defaultSignatureAlgorithms = []utls.SignatureScheme{
|
||||
0x0403, // ecdsa_secp256r1_sha256
|
||||
0x0503, // ecdsa_secp384r1_sha384
|
||||
0x0603, // ecdsa_secp521r1_sha512
|
||||
0x0807, // ed25519
|
||||
0x0808, // ed448
|
||||
0x0809, // rsa_pss_pss_sha256
|
||||
0x080a, // rsa_pss_pss_sha384
|
||||
0x080b, // rsa_pss_pss_sha512
|
||||
0x0804, // rsa_pss_rsae_sha256
|
||||
0x0805, // rsa_pss_rsae_sha384
|
||||
0x0806, // rsa_pss_rsae_sha512
|
||||
0x0401, // rsa_pkcs1_sha256
|
||||
0x0501, // rsa_pkcs1_sha384
|
||||
0x0601, // rsa_pkcs1_sha512
|
||||
0x0303, // ecdsa_sha224
|
||||
0x0301, // rsa_pkcs1_sha224
|
||||
0x0302, // dsa_sha224
|
||||
0x0402, // dsa_sha256
|
||||
0x0502, // dsa_sha384
|
||||
0x0602, // dsa_sha512
|
||||
}
|
||||
)
|
||||
|
||||
// NewDialer creates a new TLS fingerprint dialer.
|
||||
// baseDialer is used for TCP connection establishment (supports proxy scenarios).
|
||||
// If baseDialer is nil, direct TCP dial is used.
|
||||
func NewDialer(profile *Profile, baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *Dialer {
|
||||
if baseDialer == nil {
|
||||
baseDialer = (&net.Dialer{}).DialContext
|
||||
}
|
||||
return &Dialer{profile: profile, baseDialer: baseDialer}
|
||||
}
|
||||
|
||||
// NewHTTPProxyDialer creates a new TLS fingerprint dialer that works through HTTP/HTTPS proxies.
|
||||
// It establishes a CONNECT tunnel before performing TLS handshake with custom fingerprint.
|
||||
func NewHTTPProxyDialer(profile *Profile, proxyURL *url.URL) *HTTPProxyDialer {
|
||||
return &HTTPProxyDialer{profile: profile, proxyURL: proxyURL}
|
||||
}
|
||||
|
||||
// NewSOCKS5ProxyDialer creates a new TLS fingerprint dialer that works through SOCKS5 proxies.
|
||||
// It establishes a SOCKS5 tunnel before performing TLS handshake with custom fingerprint.
|
||||
func NewSOCKS5ProxyDialer(profile *Profile, proxyURL *url.URL) *SOCKS5ProxyDialer {
|
||||
return &SOCKS5ProxyDialer{profile: profile, proxyURL: proxyURL}
|
||||
}
|
||||
|
||||
// DialTLSContext establishes a TLS connection through SOCKS5 proxy with the configured fingerprint.
|
||||
// Flow: SOCKS5 CONNECT to target -> TLS handshake with utls on the tunnel
|
||||
func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
slog.Debug("tls_fingerprint_socks5_connecting", "proxy", d.proxyURL.Host, "target", addr)
|
||||
|
||||
// Step 1: Create SOCKS5 dialer
|
||||
var auth *proxy.Auth
|
||||
if d.proxyURL.User != nil {
|
||||
username := d.proxyURL.User.Username()
|
||||
password, _ := d.proxyURL.User.Password()
|
||||
auth = &proxy.Auth{
|
||||
User: username,
|
||||
Password: password,
|
||||
}
|
||||
}
|
||||
|
||||
// Determine proxy address
|
||||
proxyAddr := d.proxyURL.Host
|
||||
if d.proxyURL.Port() == "" {
|
||||
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "1080") // Default SOCKS5 port
|
||||
}
|
||||
|
||||
socksDialer, err := proxy.SOCKS5("tcp", proxyAddr, auth, proxy.Direct)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_dialer_failed", "error", err)
|
||||
return nil, fmt.Errorf("create SOCKS5 dialer: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Establish SOCKS5 tunnel to target
|
||||
slog.Debug("tls_fingerprint_socks5_establishing_tunnel", "target", addr)
|
||||
conn, err := socksDialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_connect_failed", "error", err)
|
||||
return nil, fmt.Errorf("SOCKS5 connect: %w", err)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_socks5_tunnel_established")
|
||||
|
||||
// Step 3: Perform TLS handshake on the tunnel with utls fingerprint
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
slog.Debug("tls_fingerprint_socks5_starting_handshake", "host", host)
|
||||
|
||||
// Build ClientHello specification from profile (Node.js/Claude CLI fingerprint)
|
||||
spec := buildClientHelloSpecFromProfile(d.profile)
|
||||
slog.Debug("tls_fingerprint_socks5_clienthello_spec",
|
||||
"cipher_suites", len(spec.CipherSuites),
|
||||
"extensions", len(spec.Extensions),
|
||||
"compression_methods", spec.CompressionMethods,
|
||||
"tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
|
||||
"tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
|
||||
|
||||
if d.profile != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
|
||||
}
|
||||
|
||||
// Create uTLS connection on the tunnel
|
||||
tlsConn := utls.UClient(conn, &utls.Config{
|
||||
ServerName: host,
|
||||
}, utls.HelloCustom)
|
||||
|
||||
if err := tlsConn.ApplyPreset(spec); err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_apply_preset_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("apply TLS preset: %w", err)
|
||||
}
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
}
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_socks5_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// DialTLSContext establishes a TLS connection through HTTP proxy with the configured fingerprint.
|
||||
// Flow: TCP connect to proxy -> CONNECT tunnel -> TLS handshake with utls
|
||||
func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
slog.Debug("tls_fingerprint_http_proxy_connecting", "proxy", d.proxyURL.Host, "target", addr)
|
||||
|
||||
// Step 1: TCP connect to proxy server
|
||||
var proxyAddr string
|
||||
if d.proxyURL.Port() != "" {
|
||||
proxyAddr = d.proxyURL.Host
|
||||
} else {
|
||||
// Default ports
|
||||
if d.proxyURL.Scheme == "https" {
|
||||
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "443")
|
||||
} else {
|
||||
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "80")
|
||||
}
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", proxyAddr)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_http_proxy_connect_failed", "error", err)
|
||||
return nil, fmt.Errorf("connect to proxy: %w", err)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_http_proxy_connected", "proxy_addr", proxyAddr)
|
||||
|
||||
// Step 2: Send CONNECT request to establish tunnel
|
||||
req := &http.Request{
|
||||
Method: "CONNECT",
|
||||
URL: &url.URL{Opaque: addr},
|
||||
Host: addr,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
|
||||
// Add proxy authentication if present
|
||||
if d.proxyURL.User != nil {
|
||||
username := d.proxyURL.User.Username()
|
||||
password, _ := d.proxyURL.User.Password()
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
|
||||
req.Header.Set("Proxy-Authorization", "Basic "+auth)
|
||||
}
|
||||
|
||||
slog.Debug("tls_fingerprint_http_proxy_sending_connect", "target", addr)
|
||||
if err := req.Write(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
slog.Debug("tls_fingerprint_http_proxy_write_failed", "error", err)
|
||||
return nil, fmt.Errorf("write CONNECT request: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Read CONNECT response
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(br, req)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
slog.Debug("tls_fingerprint_http_proxy_read_response_failed", "error", err)
|
||||
return nil, fmt.Errorf("read CONNECT response: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = conn.Close()
|
||||
slog.Debug("tls_fingerprint_http_proxy_connect_failed_status", "status_code", resp.StatusCode, "status", resp.Status)
|
||||
return nil, fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_http_proxy_tunnel_established")
|
||||
|
||||
// Step 4: Perform TLS handshake on the tunnel with utls fingerprint
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
slog.Debug("tls_fingerprint_http_proxy_starting_handshake", "host", host)
|
||||
|
||||
// Build ClientHello specification (reuse the shared method)
|
||||
spec := buildClientHelloSpecFromProfile(d.profile)
|
||||
slog.Debug("tls_fingerprint_http_proxy_clienthello_spec",
|
||||
"cipher_suites", len(spec.CipherSuites),
|
||||
"extensions", len(spec.Extensions))
|
||||
|
||||
if d.profile != nil {
|
||||
slog.Debug("tls_fingerprint_http_proxy_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
|
||||
}
|
||||
|
||||
// Create uTLS connection on the tunnel
|
||||
// Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
|
||||
tlsConn := utls.UClient(conn, &utls.Config{
|
||||
ServerName: host,
|
||||
}, utls.HelloCustom)
|
||||
|
||||
if err := tlsConn.ApplyPreset(spec); err != nil {
|
||||
slog.Debug("tls_fingerprint_http_proxy_apply_preset_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("apply TLS preset: %w", err)
|
||||
}
|
||||
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
slog.Debug("tls_fingerprint_http_proxy_handshake_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
}
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_http_proxy_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// DialTLSContext establishes a TLS connection with the configured fingerprint.
|
||||
// This method is designed to be used as http.Transport.DialTLSContext.
|
||||
func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// Establish TCP connection using base dialer (supports proxy)
|
||||
slog.Debug("tls_fingerprint_dialing_tcp", "addr", addr)
|
||||
conn, err := d.baseDialer(ctx, network, addr)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_tcp_dial_failed", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
slog.Debug("tls_fingerprint_tcp_connected", "addr", addr)
|
||||
|
||||
// Extract hostname for SNI
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
slog.Debug("tls_fingerprint_sni_hostname", "host", host)
|
||||
|
||||
// Build ClientHello specification
|
||||
spec := d.buildClientHelloSpec()
|
||||
slog.Debug("tls_fingerprint_clienthello_spec",
|
||||
"cipher_suites", len(spec.CipherSuites),
|
||||
"extensions", len(spec.Extensions))
|
||||
|
||||
// Log profile info
|
||||
if d.profile != nil {
|
||||
slog.Debug("tls_fingerprint_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
|
||||
} else {
|
||||
slog.Debug("tls_fingerprint_using_default_profile")
|
||||
}
|
||||
|
||||
// Create uTLS connection
|
||||
// Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
|
||||
tlsConn := utls.UClient(conn, &utls.Config{
|
||||
ServerName: host,
|
||||
}, utls.HelloCustom)
|
||||
|
||||
// Apply fingerprint
|
||||
if err := tlsConn.ApplyPreset(spec); err != nil {
|
||||
slog.Debug("tls_fingerprint_apply_preset_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
slog.Debug("tls_fingerprint_preset_applied")
|
||||
|
||||
// Perform TLS handshake
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
slog.Debug("tls_fingerprint_handshake_failed",
|
||||
"error", err,
|
||||
"local_addr", conn.LocalAddr(),
|
||||
"remote_addr", conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
}
|
||||
|
||||
// Log successful handshake details
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// buildClientHelloSpec constructs the ClientHello specification based on the profile.
|
||||
func (d *Dialer) buildClientHelloSpec() *utls.ClientHelloSpec {
|
||||
return buildClientHelloSpecFromProfile(d.profile)
|
||||
}
|
||||
|
||||
// toUTLSCurves converts uint16 slice to utls.CurveID slice.
|
||||
func toUTLSCurves(curves []uint16) []utls.CurveID {
|
||||
result := make([]utls.CurveID, len(curves))
|
||||
for i, c := range curves {
|
||||
result[i] = utls.CurveID(c)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// buildClientHelloSpecFromProfile constructs ClientHelloSpec from a Profile.
|
||||
// This is a standalone function that can be used by both Dialer and HTTPProxyDialer.
|
||||
func buildClientHelloSpecFromProfile(profile *Profile) *utls.ClientHelloSpec {
|
||||
// Get cipher suites
|
||||
var cipherSuites []uint16
|
||||
if profile != nil && len(profile.CipherSuites) > 0 {
|
||||
cipherSuites = profile.CipherSuites
|
||||
} else {
|
||||
cipherSuites = defaultCipherSuites
|
||||
}
|
||||
|
||||
// Get curves
|
||||
var curves []utls.CurveID
|
||||
if profile != nil && len(profile.Curves) > 0 {
|
||||
curves = toUTLSCurves(profile.Curves)
|
||||
} else {
|
||||
curves = defaultCurves
|
||||
}
|
||||
|
||||
// Get point formats
|
||||
var pointFormats []uint8
|
||||
if profile != nil && len(profile.PointFormats) > 0 {
|
||||
pointFormats = profile.PointFormats
|
||||
} else {
|
||||
pointFormats = defaultPointFormats
|
||||
}
|
||||
|
||||
// Check if GREASE is enabled
|
||||
enableGREASE := profile != nil && profile.EnableGREASE
|
||||
|
||||
extensions := make([]utls.TLSExtension, 0, 16)
|
||||
|
||||
if enableGREASE {
|
||||
extensions = append(extensions, &utls.UtlsGREASEExtension{})
|
||||
}
|
||||
|
||||
// SNI extension - MUST be explicitly added for HelloCustom mode
|
||||
// utls will populate the server name from Config.ServerName
|
||||
extensions = append(extensions, &utls.SNIExtension{})
|
||||
|
||||
// Claude CLI extension order (captured from tshark):
|
||||
// server_name(0), ec_point_formats(11), supported_groups(10), session_ticket(35),
|
||||
// alpn(16), encrypt_then_mac(22), extended_master_secret(23),
|
||||
// signature_algorithms(13), supported_versions(43),
|
||||
// psk_key_exchange_modes(45), key_share(51)
|
||||
extensions = append(extensions,
|
||||
&utls.SupportedPointsExtension{SupportedPoints: pointFormats},
|
||||
&utls.SupportedCurvesExtension{Curves: curves},
|
||||
&utls.SessionTicketExtension{},
|
||||
&utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}},
|
||||
&utls.GenericExtension{Id: 22},
|
||||
&utls.ExtendedMasterSecretExtension{},
|
||||
&utls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: defaultSignatureAlgorithms},
|
||||
&utls.SupportedVersionsExtension{Versions: []uint16{
|
||||
utls.VersionTLS13,
|
||||
utls.VersionTLS12,
|
||||
}},
|
||||
&utls.PSKKeyExchangeModesExtension{Modes: []uint8{utls.PskModeDHE}},
|
||||
&utls.KeyShareExtension{KeyShares: []utls.KeyShare{
|
||||
{Group: utls.X25519},
|
||||
}},
|
||||
)
|
||||
|
||||
if enableGREASE {
|
||||
extensions = append(extensions, &utls.UtlsGREASEExtension{})
|
||||
}
|
||||
|
||||
return &utls.ClientHelloSpec{
|
||||
CipherSuites: cipherSuites,
|
||||
CompressionMethods: []uint8{0}, // null compression only (standard)
|
||||
Extensions: extensions,
|
||||
TLSVersMax: utls.VersionTLS13,
|
||||
TLSVersMin: utls.VersionTLS10,
|
||||
}
|
||||
}
|
||||
278
backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
Normal file
278
backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
Normal file
@@ -0,0 +1,278 @@
|
||||
//go:build integration
|
||||
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
//
|
||||
// Integration tests for verifying TLS fingerprint correctness.
|
||||
// These tests make actual network requests to external services and should be run manually.
|
||||
//
|
||||
// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// skipIfExternalServiceUnavailable checks if the external service is available.
|
||||
// If not, it skips the test instead of failing.
|
||||
func skipIfExternalServiceUnavailable(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
// Check for common network/TLS errors that indicate external service issues
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "certificate has expired") ||
|
||||
strings.Contains(errStr, "certificate is not yet valid") ||
|
||||
strings.Contains(errStr, "connection refused") ||
|
||||
strings.Contains(errStr, "no such host") ||
|
||||
strings.Contains(errStr, "network is unreachable") ||
|
||||
strings.Contains(errStr, "timeout") {
|
||||
t.Skipf("skipping test: external service unavailable: %v", err)
|
||||
}
|
||||
t.Fatalf("failed to get fingerprint: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
|
||||
// This test uses tls.peet.ws to verify the fingerprint.
|
||||
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
|
||||
func TestJA3Fingerprint(t *testing.T) {
|
||||
// Skip if network is unavailable or if running in short mode
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
profile := &Profile{
|
||||
Name: "Claude CLI Test",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
dialer := NewDialer(profile, nil)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Use tls.peet.ws fingerprint detection API
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
skipIfExternalServiceUnavailable(t, err)
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
}
|
||||
|
||||
// Log all fingerprint information
|
||||
t.Logf("JA3: %s", fpResp.TLS.JA3)
|
||||
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
|
||||
t.Logf("JA4: %s", fpResp.TLS.JA4)
|
||||
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
|
||||
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
|
||||
|
||||
// Verify JA3 hash matches expected value
|
||||
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
|
||||
if fpResp.TLS.JA3Hash == expectedJA3Hash {
|
||||
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
|
||||
} else {
|
||||
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
|
||||
}
|
||||
|
||||
// Verify JA4 fingerprint
|
||||
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
|
||||
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
|
||||
// The suffix _a33745022dd6_1f22a2ca17c4 should match
|
||||
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
|
||||
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
|
||||
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
|
||||
}
|
||||
|
||||
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
|
||||
// d = domain (SNI present), i = IP (no SNI)
|
||||
// Since we connect to tls.peet.ws (domain), we expect 'd'
|
||||
expectedJA4Prefix := "t13d5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
|
||||
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
|
||||
} else {
|
||||
// Also accept 'i' variant for IP connections
|
||||
altPrefix := "t13i5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
|
||||
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
|
||||
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
|
||||
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
|
||||
} else {
|
||||
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
|
||||
}
|
||||
|
||||
// Verify extension list (should be 11 extensions including SNI)
|
||||
// Expected: 0-11-10-35-16-22-23-13-43-45-51
|
||||
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
|
||||
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
|
||||
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
|
||||
} else {
|
||||
t.Logf("Warning: JA3 extension list may differ")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileExpectation defines expected fingerprint values for a profile.
|
||||
type TestProfileExpectation struct {
|
||||
Profile *Profile
|
||||
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
|
||||
ExpectedJA4 string // Expected full JA4 (empty = don't check)
|
||||
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
|
||||
}
|
||||
|
||||
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
|
||||
// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
|
||||
func TestAllProfiles(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Define all profiles to test with their expected fingerprints
|
||||
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
|
||||
profiles := []TestProfileExpectation{
|
||||
{
|
||||
// Linux x64 Node.js v22.17.1
|
||||
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
|
||||
Profile: &Profile{
|
||||
Name: "linux_x64_node_v22171",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part
|
||||
},
|
||||
{
|
||||
// MacOS arm64 Node.js v22.18.0
|
||||
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
|
||||
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
|
||||
Profile: &Profile{
|
||||
Name: "macos_arm64_node_v22180",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range profiles {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.Profile.Name, func(t *testing.T) {
|
||||
fp := fetchFingerprint(t, tc.Profile)
|
||||
if fp == nil {
|
||||
return // fetchFingerprint already called t.Fatal
|
||||
}
|
||||
|
||||
t.Logf("Profile: %s", tc.Profile.Name)
|
||||
t.Logf(" JA3: %s", fp.JA3)
|
||||
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
|
||||
t.Logf(" JA4: %s", fp.JA4)
|
||||
t.Logf(" PeetPrint: %s", fp.PeetPrint)
|
||||
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
|
||||
|
||||
// Verify expectations
|
||||
if tc.ExpectedJA3 != "" {
|
||||
if fp.JA3Hash == tc.ExpectedJA3 {
|
||||
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.ExpectedJA4 != "" {
|
||||
if fp.JA4 == tc.ExpectedJA4 {
|
||||
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
|
||||
}
|
||||
}
|
||||
|
||||
// Check JA4 cipher hash (stable middle part)
|
||||
// JA4 format: prefix_cipherHash_extHash
|
||||
if tc.JA4CipherHash != "" {
|
||||
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
|
||||
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
|
||||
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
|
||||
t.Helper()
|
||||
|
||||
dialer := NewDialer(profile, nil)
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
return nil
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
skipIfExternalServiceUnavailable(t, err)
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &fpResp.TLS
|
||||
}
|
||||
160
backend/internal/pkg/tlsfingerprint/dialer_test.go
Normal file
160
backend/internal/pkg/tlsfingerprint/dialer_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
//
|
||||
// Unit tests for TLS fingerprint dialer.
|
||||
// Integration tests that require external network are in dialer_integration_test.go
|
||||
// and require the 'integration' build tag.
|
||||
//
|
||||
// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
|
||||
// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// FingerprintResponse represents the response from tls.peet.ws/api/all.
|
||||
type FingerprintResponse struct {
|
||||
IP string `json:"ip"`
|
||||
TLS TLSInfo `json:"tls"`
|
||||
HTTP2 any `json:"http2"`
|
||||
}
|
||||
|
||||
// TLSInfo contains TLS fingerprint details.
|
||||
type TLSInfo struct {
|
||||
JA3 string `json:"ja3"`
|
||||
JA3Hash string `json:"ja3_hash"`
|
||||
JA4 string `json:"ja4"`
|
||||
PeetPrint string `json:"peetprint"`
|
||||
PeetPrintHash string `json:"peetprint_hash"`
|
||||
ClientRandom string `json:"client_random"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// TestDialerWithProfile tests that different profiles produce different fingerprints.
|
||||
func TestDialerWithProfile(t *testing.T) {
|
||||
// Create two dialers with different profiles
|
||||
profile1 := &Profile{
|
||||
Name: "Profile 1 - No GREASE",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
profile2 := &Profile{
|
||||
Name: "Profile 2 - With GREASE",
|
||||
EnableGREASE: true,
|
||||
}
|
||||
|
||||
dialer1 := NewDialer(profile1, nil)
|
||||
dialer2 := NewDialer(profile2, nil)
|
||||
|
||||
// Build specs and compare
|
||||
// Note: We can't directly compare JA3 without making network requests
|
||||
// but we can verify the specs are different
|
||||
spec1 := dialer1.buildClientHelloSpec()
|
||||
spec2 := dialer2.buildClientHelloSpec()
|
||||
|
||||
// Profile with GREASE should have more extensions
|
||||
if len(spec2.Extensions) <= len(spec1.Extensions) {
|
||||
t.Error("expected GREASE profile to have more extensions")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHTTPProxyDialerBasic tests HTTP proxy dialer creation.
|
||||
// Note: This is a unit test - actual proxy testing requires a proxy server.
|
||||
func TestHTTPProxyDialerBasic(t *testing.T) {
|
||||
profile := &Profile{
|
||||
Name: "Test Profile",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
|
||||
// Test that dialer is created without panic
|
||||
proxyURL := mustParseURL("http://proxy.example.com:8080")
|
||||
dialer := NewHTTPProxyDialer(profile, proxyURL)
|
||||
|
||||
if dialer == nil {
|
||||
t.Fatal("expected dialer to be created")
|
||||
}
|
||||
if dialer.profile != profile {
|
||||
t.Error("expected profile to be set")
|
||||
}
|
||||
if dialer.proxyURL != proxyURL {
|
||||
t.Error("expected proxyURL to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSOCKS5ProxyDialerBasic tests SOCKS5 proxy dialer creation.
|
||||
// Note: This is a unit test - actual proxy testing requires a proxy server.
|
||||
func TestSOCKS5ProxyDialerBasic(t *testing.T) {
|
||||
profile := &Profile{
|
||||
Name: "Test Profile",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
|
||||
// Test that dialer is created without panic
|
||||
proxyURL := mustParseURL("socks5://proxy.example.com:1080")
|
||||
dialer := NewSOCKS5ProxyDialer(profile, proxyURL)
|
||||
|
||||
if dialer == nil {
|
||||
t.Fatal("expected dialer to be created")
|
||||
}
|
||||
if dialer.profile != profile {
|
||||
t.Error("expected profile to be set")
|
||||
}
|
||||
if dialer.proxyURL != proxyURL {
|
||||
t.Error("expected proxyURL to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildClientHelloSpec tests ClientHello spec construction.
|
||||
func TestBuildClientHelloSpec(t *testing.T) {
|
||||
// Test with nil profile (should use defaults)
|
||||
spec := buildClientHelloSpecFromProfile(nil)
|
||||
|
||||
if len(spec.CipherSuites) == 0 {
|
||||
t.Error("expected cipher suites to be set")
|
||||
}
|
||||
if len(spec.Extensions) == 0 {
|
||||
t.Error("expected extensions to be set")
|
||||
}
|
||||
|
||||
// Verify default cipher suites are used
|
||||
if len(spec.CipherSuites) != len(defaultCipherSuites) {
|
||||
t.Errorf("expected %d cipher suites, got %d", len(defaultCipherSuites), len(spec.CipherSuites))
|
||||
}
|
||||
|
||||
// Test with custom profile
|
||||
customProfile := &Profile{
|
||||
Name: "Custom",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{0x1301, 0x1302},
|
||||
}
|
||||
spec = buildClientHelloSpecFromProfile(customProfile)
|
||||
|
||||
if len(spec.CipherSuites) != 2 {
|
||||
t.Errorf("expected 2 cipher suites, got %d", len(spec.CipherSuites))
|
||||
}
|
||||
}
|
||||
|
||||
// TestToUTLSCurves tests curve ID conversion.
|
||||
func TestToUTLSCurves(t *testing.T) {
|
||||
input := []uint16{0x001d, 0x0017, 0x0018}
|
||||
result := toUTLSCurves(input)
|
||||
|
||||
if len(result) != len(input) {
|
||||
t.Errorf("expected %d curves, got %d", len(input), len(result))
|
||||
}
|
||||
|
||||
for i, curve := range result {
|
||||
if uint16(curve) != input[i] {
|
||||
t.Errorf("curve %d: expected 0x%04x, got 0x%04x", i, input[i], uint16(curve))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to parse URL without error handling.
|
||||
func mustParseURL(rawURL string) *url.URL {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user