Compare commits

..

296 Commits

Author SHA1 Message Date
shaw
426d691c95 fix(urlvalidator): 移除 ValidateURLFormat 返回值的末尾斜杠
修复 API Key 账号 base_url 末尾带斜杠时导致的双斜杠问题
2026-01-26 10:21:41 +08:00
shaw
e9a4c8ab97 docs: 修改演示站点域名 2026-01-26 10:04:44 +08:00
Wesley Liddick
34cc02f8c7 Merge pull request #393 from IanShaw027/fix/gemini-thought-signature-preserve
fix(gemini): 修复 thoughtSignature 跨账号验证错误
2026-01-26 09:23:46 +08:00
Wesley Liddick
624d9fddb7 Merge pull request #391 from geminiwen/main
fix(subscription): 修复订阅调整逻辑,已过期订阅从当前时间计算
2026-01-26 09:23:29 +08:00
Wesley Liddick
47fbe43324 Merge pull request #385 from DDZS987/fix/oauth-token-refresh-missing-project-id-retry
fix(oauth): 修复 OAuth 令牌刷新时 missing_project_id 误报问题
2026-01-26 09:22:48 +08:00
shaw
1245f07a2d feat(auth): 实现 TOTP 双因素认证功能
新增功能:
- 支持 Google Authenticator 等应用进行 TOTP 二次验证
- 用户可在个人设置中启用/禁用 2FA
- 登录时支持 TOTP 验证流程
- 管理后台可全局开关 TOTP 功能

安全增强:
- TOTP 密钥使用 AES-256-GCM 加密存储
- 添加 TOTP_ENCRYPTION_KEY 配置项,必须手动配置才能启用功能
- 防止服务重启导致加密密钥变更使用户无法登录
- 验证失败次数限制,防止暴力破解

配置说明:
- Docker 部署:在 .env 中设置 TOTP_ENCRYPTION_KEY
- 非 Docker 部署:在 config.yaml 中设置 totp.encryption_key
- 生成密钥命令:openssl rand -hex 32
2026-01-26 09:19:53 +08:00
ianshaw
839975b0cf feat(gemini): 支持 Gemini CLI 粘性会话与跨账号 thoughtSignature 清理
## 问题背景

1. Gemini CLI 没有明确的会话标识(如 Claude Code 的 metadata.user_id)
2. thoughtSignature 与具体上游账号强绑定,跨账号使用会导致 400 错误
3. 粘性会话切换账号或 cache 丢失时,旧签名会导致请求失败

## 解决方案

### 1. Gemini CLI 会话标识提取

- 从 `x-gemini-api-privileged-user-id` header 和请求体中的 tmp 目录哈希生成会话标识
- 组合策略:SHA256(privileged-user-id + ":" + tmp_dir_hash)
- 正则提取:`/\.gemini/tmp/([A-Fa-f0-9]{64})`

### 2. 跨账号 thoughtSignature 清理

实现三种场景的智能清理:

1. **Cache 命中 + 账号切换**
   - 粘性会话绑定的账号与当前选择的账号不同时清理

2. **同一请求内 failover 切换**
   - 通过 sessionBoundAccountID 跟踪,检测重试时的账号切换

3. **Gemini CLI + Cache 未命中 + 含签名**
   - 预防性清理,避免 cache 丢失后首次转发就 400
   - 仅对 Gemini CLI 请求且请求体包含 thoughtSignature 时触发

## 修改内容

### backend/internal/handler/gemini_v1beta_handler.go
- 添加 `extractGeminiCLISessionHash` 函数提取 Gemini CLI 会话标识
- 添加 `isGeminiCLIRequest` 函数识别 Gemini CLI 请求
- 实现账号切换检测与 thoughtSignature 清理逻辑
- 添加 `geminiCLITmpDirRegex` 正则表达式

### backend/internal/service/gateway_service.go
- 添加 `GetCachedSessionAccountID` 方法查询粘性会话绑定的账号 ID

### backend/internal/service/gemini_native_signature_cleaner.go (新增)
- 实现 `CleanGeminiNativeThoughtSignatures` 函数
- 递归清理 JSON 中的所有 thoughtSignature 字段
- 支持任意 JSON 顶层类型(object/array)

### backend/internal/handler/gemini_cli_session_test.go (新增)
- 测试 Gemini CLI 会话哈希提取逻辑
- 测试 tmp 目录正则匹配
- 覆盖有/无 privileged-user-id 的场景

## 影响范围

- 修复 Gemini CLI 多轮对话时账号切换导致的 400 错误
- 提高粘性会话的稳定性和容错能力
- 不影响其他客户端(Claude Code 等)的会话标识生成

## 测试

- 单元测试:go test -tags=unit ./internal/handler -run TestExtractGeminiCLISessionHash
- 单元测试:go test -tags=unit ./internal/handler -run TestGeminiCLITmpDirRegex
- 编译验证:go build ./cmd/server
2026-01-26 04:40:38 +08:00
ianshaw
8c1233393f fix(antigravity): 修复 Gemini 模型 thoughtSignature 被错误覆盖的问题
## 问题描述

在使用 Gemini 模型(gemini-3-flash-preview)时,出现 400 错误:
"Unable to submit request because Thought signature is not valid"

## 根本原因

在 `request_transformer.go` 的 `buildParts()` 函数中:
- 对于 `tool_use` 和 `thinking` 块,当 `allowDummyThought=true`(Gemini 模型)时
- 代码会无条件将客户端传入的真实 `thoughtSignature` 覆盖成 dummy 值
- 导致 Gemini API 验证签名失败(签名与上下文不匹配)

## 修复方案

修改 signature 处理逻辑:
1. **优先透传真实 signature**:如果客户端提供了有效的 signature,保留它
2. **缺失时才使用 dummy**:只有在 signature 缺失且是 Gemini 模型时,才使用 dummy signature
3. **Claude 模型特殊处理**:将 dummy signature 视为缺失,避免透传到需要真实签名的链路

## 修改内容

### request_transformer.go
- `thinking` 块(第 367-386 行):优先透传真实 signature
- `tool_use` 块(第 411-418 行):优先透传真实 signature

### request_transformer_test.go
- 修改测试用例名称,反映新的行为
- 新增测试用例验证"缺失时才使用 dummy"的逻辑

## 影响范围

- 修复 Gemini 模型在多轮对话中使用 tool_use 时的签名验证错误
- 不影响 Claude 模型的现有行为
- 提高跨账号切换时的稳定性

相关问题:#[issue_number]
2026-01-26 03:06:45 +08:00
Gemini Wen
9cdb0568cc fix(subscription): 修复订阅调整逻辑,已过期订阅从当前时间计算
- 已过期订阅延长时,从当前时间开始增加天数
- 已过期订阅不允许负向调整(缩短)
- 未过期订阅保持原逻辑,从原过期时间增减

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-25 18:12:15 +08:00
shaw
74e05b83ea fix(ratelimit): 修复 OpenAI 账号限流倒计时计算错误
- 解析 x-codex-* 响应头获取正确的重置时间
- 7d 限制用尽时使用 codex_7d_reset_after_seconds
- 提取 Normalize() 方法统一窗口规范化逻辑
2026-01-25 13:32:08 +08:00
Ubuntu
4ded9e7d49 fix(oauth): 为初始 OAuth 授权添加 LoadCodeAssist 重试机制
问题:
- 初始授权时 LoadCodeAssist 没有重试机制,失败后直接跳过
- 导致账号创建时就可能缺失 project_id
- 之后每次刷新都因为 missing_project_id 报错

修复:
- 统一使用 loadProjectIDWithRetry 方法(最多4次尝试)
- 初始授权和token刷新使用相同的重试策略
- 保留原注释说明部分账户可能没有 project_id

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-24 23:41:36 +08:00
Ubuntu
716272a1e2 fix(oauth): 彻底修复 project_id 丢失问题
根本原因:
- BuildAccountCredentials 只在 project_id 非空时才添加该字段
- LoadCodeAssist 失败时返回空字符串 → 新 credentials 不包含 project_id 键
- 普通合并逻辑只保留新 credentials 中不存在的键,无法覆盖空值

解决方案:
1. 在合并后特殊处理 project_id:如果新值为空但旧值非空,保留旧值
2. LoadCodeAssist 失败不再返回错误,只记录警告
3. Token 刷新成功(access_token 已更新)就不应标记账户为 error

改进效果:
- 即使 LoadCodeAssist 连续失败,已有的 project_id 也不会丢失
- 避免因临时网络问题将账户误标记为不可用
- 允许在下次刷新时自动重试获取 project_id

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-24 23:04:48 +08:00
shaw
9cc8352593 feat(auth): 密码重置邮件队列化与限流优化
- 邮件发送改为异步队列处理,避免并发导致发送失败
- 新增 Email 维度限流(30秒冷却期),防止邮件轰炸
- Token 验证使用常量时间比较,防止时序攻击
- 重构代码消除冗余,提取公共验证逻辑
2026-01-24 22:55:28 +08:00
shaw
43a1031e38 fix(test): 修复订阅相关测试失败问题
1. 使用未来日期(2099年)作为测试订阅的过期时间,避免
   normalizeSubscriptionStatus 将测试数据标记为过期
2. 修复 List 方法调用参数不足的问题(新增 sortBy/sortOrder 参数)
2026-01-24 21:10:02 +08:00
Wesley Liddick
a5547b2f30 Merge pull request #380 from DDZS987/fix/oauth-token-refresh-missing-project-id-retry
fix(oauth): 修复 OAuth 令牌刷新时 missing_project_id 误报问题
2026-01-24 20:29:43 +08:00
shaw
b0aa23540b feat(subscription): 订阅过期状态自动更新与服务端排序
- 新增 SubscriptionExpiryService 定时任务,每分钟更新过期订阅状态
- 订阅列表支持服务端排序(按过期时间、状态、创建时间)
- 实时显示正确的过期状态,无需等待定时任务
- 允许对已过期订阅进行续期操作
- DataTable 组件支持 serverSideSort 模式
2026-01-24 20:26:01 +08:00
Ubuntu
ffaa6c4a17 fix(oauth): 修复 OAuth 令牌刷新时 missing_project_id 误报问题
问题描述:
在网络波动环境下,LoadCodeAssist 临时失败会被错误地标记为
"missing_project_id" 不可重试错误,导致账户被禁用。但实际上
账户配置正确,手动刷新后即可恢复。

解决方案:
1. 为 LoadCodeAssist 增加重试机制(最多4次,指数退避)
2. 区分真正的配置缺失和临时网络故障:
   - 如果之前有 project_id,本次获取失败只保留旧值,不报错
   - 只有从未获取过 project_id 且本次也失败时,才标记为缺失
3. 优化错误判断逻辑,避免误报

改进效果:
- 提高在复杂网络环境(如家宽代理)下的鲁棒性
- 减少因临时网络波动导致的服务中断
- 保持真正配置错误的检测能力

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-24 17:44:56 +08:00
Wesley Liddick
fbf72f0ec4 Merge pull request #377 from lynoot/fix/non-streaming-chunk-aggregation
fix(gateway): aggregate all text chunks in non-streaming Gemini responses
2026-01-24 08:49:10 +08:00
lynoot
909b8a8f9c fix(gateway): aggregate all text chunks in non-streaming Gemini responses
Previously, collectGeminiSSE() only returned the last chunk received
from the upstream streaming response when converting to non-streaming.
This caused incomplete responses where only the final text fragment
was returned to clients.

For example, a request asking to "count from 1 to 10" would only
return "\n" (the last chunk) instead of "1\n2\n3\n...\n10\n".

This was especially problematic for JSON structured output where
the opening brace "{" from the first chunk was lost, resulting
in invalid JSON like: colors": ["red", "blue"]}

The fix:
- Collect all text parts from each SSE chunk into a slice
- Merge all collected text parts into the final response
- Reuse the same pattern as handleGeminiStreamToNonStreaming
  in antigravity_gateway_service.go

Fixes: non-streaming responses returning incomplete text
Fixes: structured output (JSON schema) returning invalid JSON
2026-01-23 13:54:09 +00:00
shaw
4a0fe3b143 feat(gateway): 增加 SUGGESTION MODE 请求拦截
扩展现有的预热请求拦截功能,新增对 SUGGESTION MODE 请求的拦截:
- 检测 messages 最后一条 user 消息是否以 [SUGGESTION MODE: 开头
- 拦截后返回空内容响应,节省 token 消耗
- 重构检测逻辑,合并为单一函数,只解析一次 JSON
2026-01-23 16:57:25 +08:00
shaw
a1292fac81 feat(oauth): 支持Anthropic的Team账号使用sk授权 2026-01-23 16:30:12 +08:00
shaw
7f98be4f91 fix(oauth): 更新 Anthropic 账号 OAuth 参数,同步最新客户端 2026-01-23 16:00:42 +08:00
shaw
fd73b8875d feat(frontend): 优化账号限流状态显示,直接展示倒计时 2026-01-23 15:48:25 +08:00
shaw
f9ab1daa3c feat: 保存并显示OAuth账号邮箱地址 2026-01-23 15:17:47 +08:00
shaw
d27b847442 fix(test): 修复测试中引用不存在的函数名
测试文件引用了 IsTokenVersionStale 函数,但实际函数名为 CheckTokenVersion,导致 CI 构建失败
2026-01-23 10:58:30 +08:00
shaw
dac6bc2228 fix(token-cache): 版本过时时使用最新token而非旧token
上次修复(2665230)只阻止了写入缓存,但仍返回旧token导致403。
现在版本过时时直接使用DB中的最新token返回。

- 重构 IsTokenVersionStale 为 CheckTokenVersion,返回最新account
- 消除重复DB查询,复用版本检查时已获取的account
2026-01-23 10:29:52 +08:00
Wesley Liddick
4bd3dbf2ce Merge pull request #354 from DuckyProject/fix/frontend-table
feat(frontend): 账号表格默认排序/持久化 + 自动刷新 + 更多菜单外部关闭
2026-01-22 21:17:48 +08:00
Wesley Liddick
226df1c23a Merge pull request #358 from 0xff26b9a8/main
fix(antigravity): 修复非流式 Claude To Antigravity 响应内容为空的问题
2026-01-22 21:16:54 +08:00
shaw
2665230a09 fix(token-cache): 修复异步刷新与请求线程的缓存竞态条件
- 新增 _token_version 版本号机制,防止过期 token 污染缓存
- TokenRefreshService 刷新成功后写入版本号并清除缓存
- TokenProvider 写入缓存前检查版本,过时则跳过
- ClearError 时同步清除 token 缓存
2026-01-22 21:09:28 +08:00
0xff26b9a8
4f0c2b794c style: gofmt antigravity_gateway_service.go 2026-01-22 14:38:55 +08:00
0xff26b9a8
e756064c19 fix(antigravity): 修复非流式 Claude To Antigravity 响应内容为空的问题
- 修复 TransformGeminiToClaude 的 JSON 解析逻辑,当 V1InternalResponse
  解析成功但 candidates 为空时,尝试直接解析为 GeminiResponse 格式
- 修复 handleClaudeStreamToNonStreaming 收集流式响应的逻辑,累积所有
  chunks 的内容而不是只保留最后一个(最后一个 chunk 通常 text 为空)
- 新增 mergeCollectedPartsToResponse 函数,合并所有类型的 parts
  (text、thinking、functionCall、inlineData),保持原始顺序
- 连续的普通 text parts 合并为一个,thinking/functionCall/inlineData 保持原样
2026-01-22 14:17:59 +08:00
Wesley Liddick
17dfb0af01 Merge pull request #346 from 0xff26b9a8/main
refactor(antigravity): 提取并同步 Schema 清理逻辑至 schema_cleaner.go
2026-01-22 08:46:11 +08:00
ducky
ff74f517df feat(frontend): 账号表格默认排序/持久化 + 自动刷新 + 更多菜单外部关闭 2026-01-21 22:43:25 +08:00
0xff26b9a8
477a9a180f fix: 修复 schema 清理逻辑 2026-01-21 10:58:39 +08:00
0xff26b9a8
da48df06d2 refactor(antigravity): 提取并同步 Schema 清理逻辑至 schema_cleaner.go
主要变更:
1. 重构代码结构:
   - 将 CleanJSONSchema 及其相关辅助函数从 request_transformer.go 提取到独立的 schema_cleaner.go 文件中,实现逻辑解耦。

2. 逻辑优化与修正:
   - 参考 Antigravity-Manager (json_schema.rs) 的实现逻辑,修正了 Schema 清洗策略。
2026-01-20 23:41:53 +08:00
Wesley Liddick
39fad63ccf Merge pull request #345 from whoismonay/main
mod(frontend): 管理员订阅/兑换码分组选择展示备注
2026-01-20 16:22:53 +08:00
Wesley Liddick
5602d02b1b Merge pull request #343 from mt21625457/main
fix(调度): 完善粘性会话清理与账号调度刷新 和 启用 OpenAI OAuth HTTP/2 并修复清理任务 lint
2026-01-20 16:05:53 +08:00
shaw
81989eed1c test: add promo_code_enabled to API contract test 2026-01-20 16:02:49 +08:00
shaw
192efb84a0 feat(promo-code): complete promo code feature implementation
- Add promo_code_enabled field to SystemSettings and PublicSettings DTOs
- Add promo code validation in registration flow
- Add admin settings UI for promo code configuration
- Add i18n translations for promo code feature
2026-01-20 15:56:26 +08:00
shaw
8672347f93 fix(settings): add missing promo_code_enabled field in public settings API
The field was defined in DTO but not mapped in handler response.
2026-01-20 15:49:57 +08:00
yangjianbo
5e5d4a513b feat: 移动镜像脚本位置 2026-01-20 15:11:27 +08:00
墨颜
88b6358472 build(frontend): vite 加载开发环境变量
- 使用 loadEnv 读取 VITE_DEV_PROXY_TARGET/VITE_DEV_PORT
- 注入 public settings 与 dev proxy 使用同源后端地址
2026-01-20 15:04:18 +08:00
墨颜
dd8d5e2c42 mod(frontend): 订阅分组下拉显示备注
- 订阅管理/分配订阅:下拉项展示分组备注
- 兑换码/订阅类型:下拉项展示分组备注
- 复用 GroupOptionItem/GroupBadge 保持一致体验
2026-01-20 15:02:48 +08:00
shaw
d91e2328fb test(tlsfingerprint): add multi-profile fingerprint verification test
Add TestAllProfiles to verify TLS fingerprint configurations from
config.yaml against tls.peet.ws. Tests check JA4 cipher hash (stable
part) to validate fingerprint spoofing works correctly.
2026-01-20 14:53:15 +08:00
yangjianbo
2a16735495 fix(测试): 修复 SelectAccountWithLoadAwareness 调用缺少参数
为 gateway_multiplatform_test.go 中的 SelectAccountWithLoadAwareness
调用添加缺少的第6个参数 metadataUserID,修复 CI 测试编译错误。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 14:16:46 +08:00
yangjianbo
292f25f9ca Merge branch 'main' of https://github.com/mt21625457/aicodex2api 2026-01-20 14:02:08 +08:00
yangjianbo
c92e37775a Merge branch 'dev' 2026-01-20 13:57:08 +08:00
yangjianbo
f6ed3d1456 Merge branch 'test' into dev 2026-01-20 11:59:13 +08:00
yangjianbo
84686753e8 Merge branch 'test' of https://github.com/mt21625457/aicodex2api into test 2026-01-20 11:51:44 +08:00
yangjianbo
91f01309da fix(调度): 完善粘性会话清理与账号调度刷新
- Update/BulkUpdate 按不可调度字段触发缓存刷新
- GatewayCache 支持多前缀会话键清理
- 模型路由与混合调度优化粘性会话处理
- 补充调度与缓存相关测试覆盖
2026-01-20 11:40:55 +08:00
yangjianbo
57a1fc9d33 style(仓储): 格式化账号仓储
- gofmt 修正 lint 格式提示
2026-01-20 11:30:36 +08:00
shaw
c95a864975 docs: add TLS fingerprint tool link 2026-01-20 11:30:10 +08:00
yangjianbo
7a83db6180 fix(调度): 完善粘性会话清理与账号调度刷新
- Update/BulkUpdate 按不可调度字段触发缓存刷新
- GatewayCache 支持多前缀会话键清理
- 模型路由与混合调度优化粘性会话处理
- 补充调度与缓存相关测试覆盖
2026-01-20 11:19:32 +08:00
Wesley Liddick
a8513da7ff Merge pull request #335 from geminiwen/main
feat(subscription): 支持调整订阅时长(延长/缩短)
2026-01-20 08:52:19 +08:00
Gemini Wen
53534d3956 style(admin): 统一列设置按钮位置到刷新按钮右侧
将订阅管理和账号管理页面的列设置按钮移动到刷新按钮右侧,
与用户管理页面保持一致的布局。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:18:51 +08:00
Gemini Wen
cc07a0e295 feat(subscription): 支持调整订阅时长(延长/缩短)
- 将"延长订阅"功能改为"调整订阅",支持正数延长、负数缩短
- 后端验证:调整天数范围 -36500 到 36500,缩短后剩余天数必须 > 0
- 前端同步更新界面文案和验证逻辑

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:11:30 +08:00
Wesley Liddick
e7bc62500b Merge pull request #333 from whoismonay/main
fix: 普通用户接口移除管理员敏感字段透传
2026-01-19 21:35:51 +08:00
墨颜
c8fb9ef3a5 style(dto): 修复 gofmt 格式问题
- 修复 mappers.go 中 Notes 字段的对齐格式
- 修复 types.go 中 BulkAssignResult 结构体字段的 JSON tag 对齐

修复 golangci-lint 检查中的 gofmt 格式错误
2026-01-19 21:26:30 +08:00
Wesley Liddick
eb5e6214bc Merge pull request #332 from geminiwen/main
fix: 修复手动刷新令牌后缓存未清除导致403错误的问题
2026-01-19 20:53:06 +08:00
shaw
568d6ee10e fix: 修复测试缺少新增设置字段 2026-01-19 20:52:05 +08:00
墨颜
6aef1af76e fix(redeem): 用户兑换历史不返回备注
- 用户侧 RedeemCode DTO 移除 notes 字段,避免泄露内部备注\n- 新增 AdminRedeemCode,并调整管理员兑换码接口继续返回 notes\n- 增加 /api/v1/redeem/history 契约测试,确保用户侧响应不包含 notes
2026-01-19 20:09:35 +08:00
Gemini Wen
a54852e129 fix: 补充API契约测试中缺失的hide_ccs_import_button字段
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 20:06:36 +08:00
Gemini Wen
668118def1 fix: 修复遗漏的测试文件更新和lint错误
- 更新api_contract_test.go以匹配NewAccountHandler新增的tokenCacheInvalidator参数
- 修复errcheck lint错误,显式忽略c.Error()返回值

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 19:58:09 +08:00
yangjianbo
73e6b160f8 feat(认证): 启用 OpenAI OAuth HTTP/2 并修复清理任务 lint
为共享 req 客户端增加 HTTP/2 选项与缓存隔离
OpenAI OAuth 超时提升到 120s,并按协议控制强制
新增客户端池与 OAuth 客户端单测覆盖
修复 usage cleanup 相关 errcheck/ineffassign/staticcheck 并统一格式

测试: make test
2026-01-19 19:50:57 +08:00
Gemini Wen
6fec141de6 fix: 修复手动刷新令牌后缓存未清除导致403错误的问题
手动刷新令牌后,新token保存到数据库但Redis缓存未清除,
导致下游请求仍然使用旧的失效token,上游API返回403错误。

修复方案:在AccountHandler中注入TokenCacheInvalidator,
刷新令牌成功后调用InvalidateToken清除缓存。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 19:40:43 +08:00
墨颜
31cde6c555 fix(subscriptions): 用户订阅不返回分配信息
- 用户侧 UserSubscription DTO 移除 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段\n- 新增 AdminUserSubscription,并调整管理员订阅接口与批量分配结果使用\n- 增加 /api/v1/subscriptions 契约测试,确保用户侧响应不包含上述字段
2026-01-19 19:35:13 +08:00
shaw
b1a980f344 feat: 添加隐藏CCS导入按钮的设置选项
在管理后台设置页面新增开关,允许管理员隐藏API Keys页面的"导入CCS"按钮
2026-01-19 19:25:16 +08:00
墨颜
00d9fbd220 fix(user): 普通用户接口不返回备注
- 用户侧 dto.User 移除 notes 字段,避免泄露管理员备注\n- 新增 dto.AdminUser 并调整 /admin/users 系列接口使用\n- 前端拆分 User/AdminUser,管理端用户页面使用 AdminUser\n- 更新契约测试:/api/v1/auth/me 响应不包含 notes
2026-01-19 19:23:51 +08:00
墨颜
4f4c9679bf fix(groups): 用户分组不下发内部路由信息
- 普通用户 Group DTO 移除 model_routing/account_count/account_groups,避免泄露内部路由与账号信息\n- 新增 AdminGroup DTO,并仅在管理员分组接口返回完整字段\n- 前端拆分 Group/AdminGroup,管理端页面与 API 使用 AdminGroup\n- 增加 /api/v1/groups/available 契约测试,防止回归
2026-01-19 18:58:42 +08:00
shaw
3dab71729d feat: usage接口支持TLS指纹和缓存User-Agent 2026-01-19 17:06:16 +08:00
墨颜
2f6f758670 fix(usage): 用户使用记录不下发账号计费倍率
- 将 usage log DTO 拆分为用户/管理员两类
- 用户接口不返回 account_rate_multiplier/ip_address/account
- 管理员接口保留管理员字段
- 补充契约测试防止回归
2026-01-19 17:05:42 +08:00
shaw
090c8981dd fix: 更新Claude OAuth授权配置以匹配最新规范
- 更新TokenURL和RedirectURI为platform.claude.com
- 更新scope定义,区分浏览器URL和内部API调用
- 修正state/code_verifier生成算法使用base64url编码
- 修正授权URL参数顺序并添加code=true
- 更新token交换请求头匹配官方实现
- 清理未使用的类型和函数
2026-01-19 16:40:06 +08:00
shaw
fbb572948d fix: 修复会话数量查询使用错误的超时配置 2026-01-19 11:45:04 +08:00
shaw
a652b513d3 fix: handle 400 error for disabled organization 2026-01-19 10:54:40 +08:00
shaw
ccfeaeb22d feat: 新增会话ID伪装功能,优化日志系统
- 新增 session_id_masking_enabled 配置,启用后将在15分钟内固定
  metadata.user_id 中的 session ID
- TLS fingerprint 模块日志从自定义 debugLog 迁移到 slog
- main.go 添加 slog 初始化,根据 gin mode 设置日志级别
- 前端创建/编辑账号模态框添加会话ID伪装开关
- 多语言支持(中英文)
2026-01-19 10:22:13 +08:00
shaw
4c12799a95 fix: 补充测试桩缺失的接口方法 2026-01-19 09:28:11 +08:00
Wesley Liddick
0f8d42c577 Merge pull request #327 from mt21625457/main
feat(usage): 添加清理任务与统计过滤
2026-01-19 09:18:00 +08:00
Wesley Liddick
03c7578713 Merge pull request #325 from slovx2/main
fix(antigravity): 修复Antigravity 频繁429的问题,以及一系列优化,配置增强
2026-01-19 09:17:15 +08:00
shaw
de6797c560 fix: 修复5小时窗口费用不重置的问题
- 新增 GetCurrentWindowStartTime() 方法,当窗口过期时自动使用新的预测窗口开始时间
- UpdateSessionWindow 更新窗口时间后触发 outbox 事件同步调度器缓存
- 统一所有窗口费用查询入口使用新方法
2026-01-19 09:13:15 +08:00
shaw
46ae08ecb7 fix: 补充测试桩缺失的接口方法 2026-01-18 22:23:03 +08:00
shaw
2028cc29b7 fix: 修复多个管理后台问题
- 分页接口 page_size 最大限制从 100 改为 1000
- 通过 Redis Pub/Sub 实现跨实例认证缓存失效
- 允许订阅类型分组编辑计费倍率
- 账号计费倍率支持 3 位小数
2026-01-18 22:13:47 +08:00
shaw
f6360e0bf3 fix: 移除未使用的 extractSessionUUID 函数
修复 golangci-lint unused 检查报错
2026-01-18 20:15:02 +08:00
shaw
9abda1bc59 feat(tls): 新增 TLS 指纹模拟功能 2026-01-18 20:08:40 +08:00
yangjianbo
2a94cc76a6 fix(软删除): 增强错误处理,确保软删除操作中的错误类型正确 2026-01-18 16:51:26 +08:00
yangjianbo
150b315a7b fix(软删除): 修复软删除钩子的客户端调用逻辑,确保正确处理变更 2026-01-18 16:46:22 +08:00
shaw
a07174c191 fix: 修复会话限制功能并在创建账号时支持配额控制 2026-01-18 16:41:15 +08:00
yangjianbo
fb839ae6ca fix(软删除): 修复删除钩子调用链并跳过无Docker测试
软删除钩子改用 next.Mutate 处理更新,避免 mutation 类型不匹配
集成测试检测 Docker 可用性,无 Docker 自动跳过
2026-01-18 16:10:54 +08:00
yangjianbo
bdc426a774 Merge branch 'main' into dev 2026-01-18 15:55:58 +08:00
Wesley Liddick
32fff3798c Merge pull request #326 from geminiwen/main
feat(admin): 添加账号管理和订阅管理的列设置功能
2026-01-18 14:41:02 +08:00
Wesley Liddick
2b02c6635d Merge pull request #323 from IanShaw027/fix/ops-error-classification-consistency
fix(ops): 统一 request-errors 接口与 SLA 计算的错误分类逻辑
2026-01-18 14:32:04 +08:00
yangjianbo
771baa66ee feat(界面): 优化分页跳转与页大小显示
分页组件支持隐藏每页条数选择器并新增跳转页输入
清理任务列表启用跳转页并固定每页 5 条
补充中英文分页文案
2026-01-18 14:31:22 +08:00
Wesley Liddick
a82029b0cf Merge pull request #318 from IanShaw027/main
fix(openai): OpenCode 兼容性增强 - 工具过滤和粘性会话修复
2026-01-18 14:30:53 +08:00
Wesley Liddick
0c2a901af4 Merge pull request #317 from IanShaw027/fix/gemini-issue
fix(gemini,group): 更新 Gemini 模型配置并补齐 SIMPLE 默认分组
2026-01-18 14:30:42 +08:00
yangjianbo
bd18f4b8ef feat(清理任务): 引入Ent存储并补充日志与测试
新增 usage_cleanup_task Ent schema 与仓储实现,支持清理任务排序分页
补充清理任务全链路日志、仪表盘重算触发及 UI 过滤调整
完善 repository/service 单测并引入 sqlite 测试依赖
2026-01-18 14:18:28 +08:00
yangjianbo
bf7b79f2f0 fix(数据库): 优化任务状态更新查询,使用别名提高可读性 2026-01-18 11:58:53 +08:00
Gemini Wen
45e8598d32 feat(admin): 添加账号管理和订阅管理的列设置功能
- 账号管理新增代理列显示和列设置下拉菜单
- 订阅管理新增列设置,支持用户列在邮箱/用户名间切换
- 列设置持久化到 localStorage
- 统一列设置图标样式

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 11:57:07 +08:00
yangjianbo
8391d480c9 fix(数据库): 补充分组列回填前置保护 2026-01-18 11:52:59 +08:00
yangjianbo
d17f853a5f fix(数据库): 补充允许分组列兼容迁移 2026-01-18 11:45:43 +08:00
yangjianbo
ef5a41057f feat(usage): 添加清理任务与统计过滤 2026-01-18 10:52:18 +08:00
song
c115c9e048 fix: address lint errors 2026-01-18 01:22:40 +08:00
song
6941315432 feat: add antigravity web search support 2026-01-18 01:09:40 +08:00
song
8b071cc665 fix(antigravity): restore signature retry and base order 2026-01-17 22:50:50 +08:00
song
959f6c538a fix(antigravity): remove thinking sanitation 2026-01-17 22:21:48 +08:00
song
217b3b59c0 fix(antigravity): drop MarkUnavailable 2026-01-17 21:59:32 +08:00
song
ec916a3197 fix(antigravity): remove signature retry 2026-01-17 21:56:57 +08:00
song
22eb72e0f9 fix(antigravity): restore url fallback behavior 2026-01-17 21:50:09 +08:00
song
07ba64c666 fix(antigravity): handle url-level 429 without failover 2026-01-17 21:37:32 +08:00
song
f22bc59fe3 fix(antigravity): route signature retry through url fallback 2026-01-17 21:15:33 +08:00
song
0ce8666cc0 Revert "Revert "fix(antigravity): Claude 模型透传 tool_use 的 signature""
This reverts commit 5427a9e422.
2026-01-17 21:09:59 +08:00
song
5427a9e422 Revert "fix(antigravity): Claude 模型透传 tool_use 的 signature"
This reverts commit 81b865b89d.
2026-01-17 20:41:06 +08:00
song
5e9f5efbe3 chore: log antigravity signature retry 429 2026-01-17 18:22:53 +08:00
song
a7a0017aa8 chore: gofmt antigravity gateway service 2026-01-17 18:22:43 +08:00
song
9078b17a41 test: add antigravity rate limit coverage 2026-01-17 18:15:45 +08:00
song
14a3694a9a chore: set antigravity fallback cooldown default to 1 2026-01-17 18:03:45 +08:00
song
b9b4db3df5 Merge upstream/main 2026-01-17 18:00:07 +08:00
ianshaw
bc1d7edc58 fix(ops): 统一 request-errors 和 SLA 的错误分类逻辑
修复 request-errors 接口与 Dashboard Overview SLA 计算不一致的问题:
- errors 视图现在只排除业务限制错误(余额不足、并发限制等)
- 上游 429/529 错误现在包含在 errors 视图中,与 SLA 计算保持一致
- excluded 视图现在只显示业务限制错误

这确保了 request-errors 接口和 Dashboard 的 error_count_sla 使用相同的过滤逻辑。
2026-01-17 17:57:40 +08:00
song
5a6f60a954 fix(antigravity): 区分 URL 级别和账户配额级别的 429 限流
- "Resource has been exhausted" → URL 级别限流,立即切换 URL
- "exhausted your capacity on this model" → 账户配额限流,重试 3 次(指数退避)后标记限流
2026-01-17 11:11:18 +08:00
IanShaw027
a61cc2cb24 fix(openai): 增强 Codex 工具过滤和参数标准化
- codex_transform: 过滤无效工具,支持 Responses-style 和 ChatCompletions-style 格式
- tool_corrector: 添加 fetch 工具映射,修正 bash/edit 参数命名规范
2026-01-17 11:00:07 +08:00
song
31933c8a60 fix: 删除未使用的字段修复 lint 错误 2026-01-17 10:40:28 +08:00
song
78bccd032d refactor(antigravity): 提取公共重试循环函数减少重复代码
- 新增 antigravityRetryLoop 函数统一处理 Forward 和 ForwardGemini 的重试逻辑
- 429 日志增加 base_url 字段便于调试
- 删除重复的 shouldRetryUpstreamError 方法
2026-01-17 10:28:31 +08:00
IanShaw027
ae21db77ec fix(openai): 使用 prompt_cache_key 兜底粘性会话
opencode 请求不带 session_id/conversation_id,导致粘性会话失效。现在按 header 优先、prompt_cache_key 兜底生成 session hash,并补充单测验证优先级。
2026-01-17 02:31:16 +08:00
song
ac7503d95f fix(antigravity): 429 时也切换 URL 重试
- 429 优先切换到下一个 URL 重试
- 只有所有 URL 都返回 429 时才限流账户并返回错误
- 与 client.go 中的逻辑保持一致
2026-01-17 02:14:57 +08:00
song
69c4b17a9b feat(antigravity): 动态 URL 排序,最近成功的优先使用
- URLAvailability 新增 lastSuccess 字段追踪最近成功的 URL
- GetAvailableURLs 返回列表时优先放置 lastSuccess
- 所有 Antigravity API 调用成功后调用 MarkSuccess 更新优先级
2026-01-17 01:54:14 +08:00
IanShaw027
a7165b0f73 fix(group): SIMPLE 模式启动补齐默认分组 2026-01-17 01:53:51 +08:00
song
cc0fca35ec feat(antigravity): 同步 Antigravity-Manager 的请求逻辑
- System Prompt: 改为简短版,添加 OpenCode 过滤、MCP XML 协议注入、SYSTEM_PROMPT_END 标记
- HTTP Headers: 只保留 Content-Type/Authorization/User-Agent,移除 Accept 和 Host
- User-Agent: 改为 antigravity/1.11.9 windows/amd64
- requestType: 动态判断 (agent/web_search/image_gen)
- BaseURLs: 添加 daily sandbox 备用 URL
- Fallback: 扩展触发条件 (429/408/404/5xx)
2026-01-17 01:49:42 +08:00
Wesley Liddick
dae0d5321f Merge pull request #315 from mt21625457/main
perf(前端): 优化页面加载性能和用户体验 和 修复静态 import 导致入口文件膨胀问题
2026-01-16 23:55:45 +08:00
shaw
34415db7ed fix: 修复 api_contract_test 缺少 SessionLimitCache 参数的问题 2026-01-16 23:53:54 +08:00
IanShaw027
28e46e0e7c fix(gemini): 更新 Gemini 模型列表配置
- 移除已弃用的 1.5 系列模型
- 调整模型优先级顺序(2.0 Flash > 2.5 Flash > 2.5 Pro > 3.0 Preview)
- 同步前后端模型配置
- 更新相关测试用例和默认模型选择逻辑
2026-01-16 23:47:42 +08:00
shaw
7379423325 feat: 添加5h窗口费用控制和会话数量限制
- 支持Anthropic OAuth/SetupToken账号的5h窗口费用阈值控制
- 支持账号级别的并发会话数量限制
- 使用Redis缓存窗口费用(30秒TTL)减少数据库压力
- 费用计算基于标准费用(不含账号倍率)
2026-01-16 23:36:52 +08:00
yangjianbo
74a3c74514 perf(Caddy): 添加静态资源长期缓存配置
- 为 /assets/* 设置 1 年缓存 + immutable 标记
- 包含 logo.png 和 favicon.ico
- 移除可能干扰缓存的 Pragma/Expires 头

效果:
- 浏览器缓存命中后不再发送请求
- Cloudflare CDN 可正确缓存静态资源
- 重复访问页面秒开

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 22:49:05 +08:00
程序猿MT
3d6d131889 Merge branch 'Wei-Shaw:main' into main 2026-01-16 22:31:14 +08:00
yangjianbo
b0569d873a perf(路由预加载): 修复静态 import 导致入口文件膨胀问题
问题:
- 原实现使用静态 import() 映射表
- Rollup 静态分析时将所有 37 个视图组件引用打包进 index.js
- 导致首次加载时需要解析大量未使用的 import 语句

修复:
- 移除静态 import() 映射,改用纯路径字符串邻接表
- 通过 router.getRoutes() 动态获取组件的 import 函数
- 延迟初始化 routePrefetch,首次导航时才创建实例
- 更新测试文件使用 mock router

效果:
- index.js 中动态 import 引用从 37 个减少到 1 个
- 首次加载不再包含未使用的视图组件引用
- 41 个测试全部通过

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 22:07:39 +08:00
yangjianbo
d9433699db Merge branch 'dev' of https://github.com/mt21625457/aicodex2api into dev 2026-01-16 21:49:28 +08:00
yangjianbo
92234857f7 perf(前端): 优化页面加载性能和用户体验
- 添加路由预加载功能,使用 requestIdleCallback 在浏览器空闲时预加载
- 配置 Vite manualChunks 分离 vendor 库(vue/ui/chart/i18n/misc)
- 新增 NavigationProgress 导航进度条组件,支持防闪烁和无障碍
- 集成 Vitest 测试框架,添加 40 个单元测试和集成测试
- 支持 prefers-reduced-motion 和暗色模式

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 21:48:57 +08:00
yangjianbo
8efa361728 perf(前端): 优化页面加载性能和用户体验
- 添加路由预加载功能,使用 requestIdleCallback 在浏览器空闲时预加载
- 配置 Vite manualChunks 分离 vendor 库(vue/ui/chart/i18n/misc)
- 新增 NavigationProgress 导航进度条组件,支持防闪烁和无障碍
- 集成 Vitest 测试框架,添加 40 个单元测试和集成测试
- 支持 prefers-reduced-motion 和暗色模式

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 21:43:39 +08:00
song
1be3eacad5 feat(scheduling): 兜底层账户选择策略可配置
- gateway.scheduling.fallback_selection_mode: "last_used"(默认) 或 "random"
- last_used: 按最后使用时间排序(轮询效果)
- random: 同优先级内随机选择
2026-01-16 20:47:07 +08:00
song
34d6b0a601 feat(gateway): 账户切换次数和 Antigravity 限流时间可配置
- gateway.max_account_switches: 账户切换最大次数,默认 10
- gateway.max_account_switches_gemini: Gemini 账户切换次数,默认 3
- gateway.antigravity_fallback_cooldown_minutes: Antigravity 429 fallback 限流时间,默认 5 分钟
- Antigravity 429 不再重试,直接标记账户限流
2026-01-16 20:18:30 +08:00
yangjianbo
eb432a49ed perf(前端): 移除 Google Fonts 改用系统字体栈
- 删除 Google Fonts @import,解决国内访问阻塞问题
- 使用 system-ui 优先的系统字体栈
- 添加中文字体支持(苹方、冬青黑、微软雅黑)
- 移除 Inter 字体专用的 font-feature-settings

此改动可显著提升国内用户的页面加载速度,避免因 Google Fonts
被墙导致的渲染阻塞问题。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 19:49:33 +08:00
Wesley Liddick
04811c00cb Merge pull request #313 from mt21625457/main
fix(ci): 修复各类bug
2026-01-16 19:30:07 +08:00
Wesley Liddick
06093d4f79 Merge pull request #311 from longgexx/main
添加分组级别模型路由配置功能(Anthropic平台)
2026-01-16 19:22:12 +08:00
song
2055a60bcb fix(antigravity): 429 重试3次后限流账户
- 收到429后重试最多3次(指数退避)
- 3次都失败后调用 handleUpstreamError 限流账户
- 移除无效的 URL fallback 逻辑(当前只有一个URL)
2026-01-16 18:51:07 +08:00
song
cc892744bc fix(antigravity): 429 fallback 改为 5 分钟并限流整个账户
- fallback 时间从 1 分钟改为 5 分钟
- fallback 时直接限流整个账户而非仅限制 quota scope
2026-01-16 18:09:34 +08:00
longgexx
577ee16108 Merge branch 'main' of github.com:longgexx/sub2api 2026-01-16 17:35:44 +08:00
longgexx
392a8ac7ea 修复格式问题。 2026-01-16 17:35:17 +08:00
longgexx
226920064b Merge branch 'Wei-Shaw:main' into main 2026-01-16 17:26:54 +08:00
longgexx
19865b865f feat(group): 添加分组级别模型路由配置功能
支持为分组配置模型路由规则,可以指定特定模型模式优先使用的账号列表。

  - 新增 model_routing 字段存储路由配置(JSONB格式,支持通配符匹配)

  - 新增 model_routing_enabled 字段控制是否启用路由

  - 更新后端 handler/service/repository 支持路由配置的增删改查

  - 更新前端 GroupsView 添加路由配置界面

  - 添加数据库迁移脚本 040/041
2026-01-16 17:26:05 +08:00
yangjianbo
e3f812c2fe fix(安全): CSP 策略自动增强,无需配置文件修改即可生效
- 添加 enhanceCSPPolicy() 自动增强任何 CSP 策略
- 自动添加 nonce 占位符(如果策略中没有)
- 自动添加 Cloudflare Insights 域名
- 即使配置文件使用旧策略也能正常工作
- 添加 enhanceCSPPolicy 和 addToDirective 单元测试

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 17:20:39 +08:00
yangjianbo
c9f79dee66 feat(安全): 实现 CSP nonce 支持解决内联脚本安全问题
- 添加 GenerateNonce() 生成加密安全的随机 nonce
- SecurityHeaders 中间件为每个请求生成唯一 nonce
- CSP 策略支持 __CSP_NONCE__ 占位符动态替换
- embed_on.go 注入的内联脚本添加 nonce 属性
- 添加 Cloudflare Insights 域名到 CSP 允许列表
- 添加完整单元测试,覆盖率达到 89.8%

解决的问题:
- 内联脚本违反 CSP script-src 指令
- Cloudflare Insights beacon.min.js 加载被阻止

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 17:05:49 +08:00
yangjianbo
c659788022 fix(前端路由): 添加 chunk 加载错误自动恢复机制
- 检测动态导入模块加载失败错误
- 自动刷新页面获取最新资源
- 使用 sessionStorage 防止无限刷新循环(10秒冷却)
- 解决前端重新部署后用户缓存导致的加载失败问题

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 16:20:10 +08:00
yangjianbo
aeb987ceb1 Merge branch 'dev' 2026-01-16 15:35:38 +08:00
yangjianbo
b478982484 style(界面): 修复 AccountsView ESLint 报错
移除多余分号并整理 onMounted 代码块,确保 lint:check 通过。
2026-01-16 15:25:47 +08:00
yangjianbo
fe71ee57b3 fix(定时轮): 初始化失败返回错误并补充单测
- NewTimingWheelService 改为返回 error,避免 panic

- ProvideTimingWheelService 透传 error 并更新 wire 生成代码

- 补充定时任务调度/取消/周期任务相关单元测试
2026-01-16 15:25:33 +08:00
song
fba3d21a35 fix: 使用 Contains 匹配 missing_project_id 并修复测试 mock 2026-01-16 14:18:12 +08:00
song
455576300c fix(antigravity): 使用 Contains 匹配 missing_project_id 错误信息 2026-01-16 14:03:25 +08:00
song
821968903c feat(antigravity): 手动刷新令牌时自动恢复 missing_project_id 错误账户状态
- 当手动刷新成功获取到 project_id,且之前错误为 missing_project_id 时,自动清除错误状态
- 后台自动刷新时同样支持状态恢复
2026-01-16 13:18:00 +08:00
wfunc
452fa53c0d feat: Claude Sonnet 429 仅限模型限流 2026-01-16 13:03:04 +08:00
song
95fe1e818f fix: Antigravity 刷新 token 时检测 project_id 缺失
- 刷新 token 后调用 LoadCodeAssist 获取 project_id
- 如果获取失败,保留原有 project_id,标记账户为 error
- token 仍会正常更新,不影响凭证刷新
- 错误信息:账户缺少project id,可能无法使用Antigravity
2026-01-16 12:13:54 +08:00
song
a61042bca0 fix: Antigravity project_id 获取优化
- API URL 改为只使用 prod 端点
- 刷新 token 时每次调用 LoadCodeAssist 更新 project_id
- 移除随机生成 project_id 的兜底逻辑
2026-01-16 11:57:14 +08:00
song
b4abfae4de fix: Antigravity 测试连接使用最小 token 消耗
- buildGeminiTestRequest: 输入 "." + maxOutputTokens: 1
- buildClaudeTestRequest: 输入 "." + MaxTokens: 1
- buildGenerationConfig: 支持透传 MaxTokens 参数
2026-01-16 10:31:55 +08:00
Wesley Liddick
c02c8646a6 Merge pull request #304 from IanShaw027/feature/codex-tool-correction
feat(openai): 添加Codex工具调用自动修正功能
2026-01-16 08:50:11 +08:00
Wesley Liddick
3ff2ca8d41 Merge pull request #303 from IanShaw027/feature/ops-account-health-score
feat(ops): 运维监控功能增强与优化
2026-01-16 08:49:52 +08:00
IanShaw027
415840088e fix(lint): 修复剩余的errcheck错误
修复了测试文件中剩余的6处类型断言未检查错误:
- 第115-118行:choices.message.tool_calls 的类型断言链
- 第140和145行:multiple tool calls 测试的类型断言
- 第343和345行:ComplexSSEData 测试的类型断言

**修复模式:**
所有类型断言都改为使用 ok 检查:
```go
// 修复前
choices := payload["choices"].([]any)

// 修复后
choices, ok := payload["choices"].([]any)
if !ok || len(choices) == 0 {
    t.Fatal("No choices found in result")
}
```

**测试验证:**
-  TestCorrectToolCallsInSSEData - 所有子测试通过
-  TestComplexSSEData - 通过
-  TestCorrectToolParameters - 通过
-  所有类型断言都有 ok 检查
-  添加了数组长度验证

现在所有 errcheck 错误都已修复。
2026-01-16 00:14:19 +08:00
IanShaw027
c4f6c89b65 fix(lint): 修复golangci-lint检查发现的问题
修复了4个lint问题:
1. errcheck (3处): 在测试中添加类型断言的ok检查
2. govet copylocks (1处): 将mutex从ToolCorrectionStats移到CodexToolCorrector

**详细修改:**

1. **openai_tool_corrector_test.go**
   - 添加了类型断言的ok检查,避免panic
   - 在解析JSON后检查payload结构的有效性
   - 改进错误处理和测试可靠性

2. **openai_tool_corrector.go**
   - 将sync.RWMutex从ToolCorrectionStats移到CodexToolCorrector
   - 避免在GetStats()返回时复制mutex
   - 保持线程安全的同时符合Go最佳实践

**测试验证:**
- 所有单元测试通过 
- go vet 检查通过 
- 代码编译正常 
2026-01-16 00:02:22 +08:00
IanShaw027
539b41f421 feat(openai): 添加Codex工具调用自动修正功能
实现了完整的Codex工具调用拦截和自动修正系统,解决OpenCode使用Codex模型时的工具调用兼容性问题。

**核心功能:**

1. **工具名称自动映射**
   - apply_patch/applyPatch → edit
   - update_plan/updatePlan → todowrite
   - read_plan/readPlan → todoread
   - search_files/searchFiles → grep
   - list_files/listFiles → glob
   - read_file/readFile → read
   - write_file/writeFile → write
   - execute_bash/executeBash/exec_bash/execBash → bash

2. **工具参数自动修正**
   - bash: 自动移除不支持的 workdir/work_dir 参数
   - edit: 自动将 path 参数重命名为 file_path
   - 支持 JSON 字符串和对象两种参数格式

3. **流式响应集成**
   - 在 SSE 数据流中实时修正工具调用
   - 支持多种 JSON 结构(tool_calls, function_call, delta, choices等)
   - 不影响响应性能和用户体验

4. **统计和监控**
   - 记录每次工具修正的详细信息
   - 提供修正统计数据查询
   - 便于问题排查和性能优化

**实现文件:**
- `openai_tool_corrector.go`: 工具修正核心逻辑(250行)
- `openai_tool_corrector_test.go`: 完整的单元测试(380+行)
- `openai_gateway_service.go`: 流式响应集成
- `openai_gateway_service_tool_correction_test.go`: 集成测试

**测试覆盖:**
- 工具名称映射测试(18个映射规则)
- 参数修正测试(bash workdir、edit path等)
- SSE数据修正测试(多种JSON结构)
- 统计功能测试
- 所有测试通过 

**解决的问题:**
修复了 OpenCode 使用 sub2api 中转 Codex 时,因工具名称和参数不兼容导致的工具调用失败问题。
Codex 模型有时会忽略指令文件中的工具映射说明,导致调用不存在的工具(如 apply_patch)。
现在通过流式响应拦截,自动将错误的工具调用修正为 OpenCode 兼容的格式。

**参考文档:**
- OpenCode 工具规范: https://opencode.ai/docs/
- Codex Bridge 指令: backend/internal/service/prompts/codex_opencode_bridge.txt
2026-01-15 23:52:50 +08:00
IanShaw027
b2ff326ced fix(ops): 调整健康分数权重以修复CI测试
- 将业务健康和基础设施健康的权重从80/20调整为70/30
- 使基础设施故障(DB/Redis)对总分影响更明显
- 修复三个失败的测试用例:
  * DB故障: 92→88 (期望70-90)
  * Redis故障: 96→94 (期望85-95)
  * 业务降级: 82→84.5 (期望84-85)
2026-01-15 23:02:15 +08:00
IanShaw027
8b95d16220 refactor(ops): 简化自动刷新定时器逻辑
- 合并双定时器为单一倒计时定时器
- 倒计时归零时触发数据刷新
- 添加自定义时间范围的安全回退
2026-01-15 22:07:23 +08:00
IanShaw027
a478822b8e refactor(ops): 优化文案显示
- TTFT 定义统一改为"首 Token"/"First Token"(而非"首字节"/"first byte")
- 请求时长卡片标题去掉"(毫秒)"/"(ms)"后缀
2026-01-15 21:37:35 +08:00
IanShaw027
23aa69f56f refactor(ops): 优化任务心跳和组件刷新机制
后端改动:
- 添加 ops_job_heartbeats.last_result 字段记录任务执行结果
- 优化告警评估器统计信息(规则数/事件数/邮件数)
- 统一各定时任务的心跳记录格式

前端改动:
- 重构 OpsConcurrencyCard 使用父组件统一控制刷新节奏
- 移除独立的 5 秒刷新定时器,改用 refreshToken 机制
- 修复 TypeScript 类型错误
2026-01-15 21:31:55 +08:00
Wesley Liddick
b36f3db9de Merge pull request #300 from mt21625457/main
feat(网关): 引入 OpenAI/Claude OAuth token 缓存
2026-01-15 20:10:16 +08:00
IanShaw027
e93f086485 fix(ops): 请求时长详情显示所有请求
- 移除请求时长卡片详情按钮的 min_duration_ms 参数限制
- 现在点击详情会显示所有请求,按时长倒序排列
- 不再只显示 P99 以上的请求
2026-01-15 19:57:19 +08:00
IanShaw027
930e9ee55c feat(ops): 添加自定义时间范围选择功能
功能特性:
- 在时间段选择器中增加"自定义"选项
- 点击后弹出对话框,支持选择任意时间范围
- 使用 HTML5 datetime-local 输入框,体验友好
- 自定义时显示格式化的时间范围标签(MM-DD HH:mm ~ MM-DD HH:mm)
- 默认初始化为最近1小时

技术实现:
- 扩展 TimeRange 类型支持 'custom'
- 添加 customStartTime 和 customEndTime 状态管理
- 创建 buildApiParams 辅助函数统一处理 API 参数
- 当选择自定义时,使用 start_time 和 end_time 参数替代 time_range
- 更新所有相关 API 调用支持自定义时间范围

国际化:
- 添加"自定义"、"开始时间"、"结束时间"翻译
2026-01-15 19:50:47 +08:00
IanShaw027
38961ba10e refactor(ops): 优化阈值检查系统和布局
阈值检查系统优化:
- 引入三级阈值系统(normal/warning/critical)
- 统一阈值判断逻辑,支持警告和严重两个级别
- 移除硬编码的 TTFT 颜色判断,改用阈值配置
- 新增 getThresholdColorClass 统一颜色映射

布局优化:
- 优化详细指标在卡片内的响应式布局
- 改进宽屏下的卡片布局显示
- 优化指标数值的对齐和间距
2026-01-15 19:50:31 +08:00
IanShaw027
93b5b7474b refactor(ops): 调整健康得分权重
- 业务健康权重从 70% 提升到 80%
- 基础设施健康权重从 30% 降低到 20%
- 更加关注业务指标(SLA、错误率等)对整体健康的影响
2026-01-15 19:50:02 +08:00
yangjianbo
f862ddc9ff style: 修复 gofmt 格式化问题
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 19:42:18 +08:00
程序猿MT
b59032304c Merge branch 'Wei-Shaw:main' into main 2026-01-15 19:34:11 +08:00
yangjianbo
3ba4d535e3 Merge branch 'dev' 2026-01-15 19:33:32 +08:00
yangjianbo
5b37e9aea4 fix(OAuth缓存): 修复缓存键冲突、401强制刷新及Redis降级处理
- Gemini 缓存键统一增加 gemini: 前缀,避免与其他平台命名空间冲突
- OAuth 账号 401 错误时设置 expires_at=now 并持久化,强制下次请求刷新 token
- Redis 锁获取失败时降级为无锁刷新,仅在 token 接近过期时执行,并检查 ctx 取消状态

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 19:08:07 +08:00
yangjianbo
1820389a05 feat(网关): 引入 OpenAI/Claude OAuth token 缓存
新增 OpenAI/Claude TokenProvider 与缓存键生成
扩展 OAuth 缓存失效覆盖更多平台
统一 OAuth 缓存前缀与依赖注入
2026-01-15 18:27:06 +08:00
yangjianbo
35e3a89385 fix(忽略文件): 添加 .serena 目录到 .gitignore 2026-01-15 16:34:42 +08:00
shaw
5f890e85e7 Merge PR #296: OAuth token cache invalidation on 401 and refresh 2026-01-15 16:31:37 +08:00
Wesley Liddick
10bc7f7042 Merge pull request #297 from LLLLLLiulei/feat/ip-management-enhancements
feat: add proxy geo location
2026-01-15 16:06:36 +08:00
yangjianbo
a65fd9dee8 Merge branch 'test' 2026-01-15 15:58:47 +08:00
yangjianbo
1bb4c76deb fix(账号管理): 移除调度切换后的冗余列表刷新
切换账号调度状态后,updateSchedulableInList 已完成局部更新,
无需再调用 load() 刷新整个列表。此修改减少不必要的 API 请求,
避免 UI 闪烁。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:57:47 +08:00
LLLLLLiulei
aab44f9fc8 feat: add proxy geo location 2026-01-15 15:15:20 +08:00
yangjianbo
0a848e7578 Merge branch 'test' 2026-01-15 15:15:07 +08:00
yangjianbo
90bce60b85 feat: merge dev 2026-01-15 15:14:44 +08:00
程序猿MT
c22d51ee41 Merge branch 'Wei-Shaw:main' into main 2026-01-15 15:12:16 +08:00
yangjianbo
a458e684bc fix(认证): OAuth 401 直接标记错误状态
- OAuth 401 清理缓存并设置错误状态

- 移除 oauth_401_cooldown_minutes 配置及示例

- 更新 401 相关单测

破坏性变更: OAuth 401 不再临时不可调度,需手动恢复
2026-01-15 15:06:34 +08:00
LLLLLLiulei
87b4662993 Revert "feat: add proxy geo location"
This reverts commit 09c4f82927ddce1c9528c146a26457f53d02b034.
2026-01-15 15:01:50 +08:00
LLLLLLiulei
3a100339b9 feat: add proxy geo location 2026-01-15 15:01:50 +08:00
Wesley Liddick
47eb3c8888 Merge pull request #284 from longgexx/main
fix(admin): 修复使用记录页面趋势图筛选联动和日期选择问题
2026-01-15 13:43:00 +08:00
longgexx
4672a6fac3 merge: 合并上游 upstream/main 分支
解决冲突:
- usage_log_repo.go: 保留 groupID 和 stream 参数,同时合并上游的
  account_rate_multiplier 逻辑

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 13:09:26 +08:00
longgexx
82743704e4 fix(test): 添加测试辅助函数 truncateToDayUTC 修复编译错误
在 usage_log_repo_integration_test.go 中添加本地的 truncateToDayUTC
辅助函数,修复因主代码重命名该函数导致的测试编译错误。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 11:42:35 +08:00
Wesley Liddick
cc2d064ab4 Merge pull request #289 from IanShaw027/fix/mobile-ui
fix(mobile): 优化移动端表格、操作栏和弹窗显示
2026-01-15 11:31:16 +08:00
Wesley Liddick
27214f8657 Merge pull request #285 from IanShaw027/fix/ops-bug
feat(ops): 增强错误日志管理、告警静默和前端 UI 优化
2026-01-15 11:26:16 +08:00
Wesley Liddick
28de614dfb Merge pull request #282 from LLLLLLiulei/feat/ip-management-enhancements
feat: enhance proxy management
2026-01-15 11:23:17 +08:00
longgexx
850183c269 fix(dashboard): 修复预聚合表使用UTC时区导致今日统计不准确的问题
将 dashboard_aggregation_repo.go 和 usage_log_repo.go 中的时区处理
从 UTC 改为使用服务器配置时区(默认 Asia/Shanghai),确保"今日"
统计数据与用户预期一致。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 11:22:13 +08:00
shaw
2a5ef6d3f5 Merge PR #279: feat(计费): 账号计费倍率快照与账号口径费用统计 2026-01-15 11:11:55 +08:00
IanShaw027
1d231c6cc3 fix(mobile): 修复 UsersView 更多菜单定位并统一逻辑
**问题描述**:
- UsersView 的"更多"菜单仍然出现在页面左上角错误位置
- UsersView 使用 actionButtonRefs Map 获取按钮元素,导致定位失败
- UsersView 和 AccountsView 的菜单定位逻辑不一致,难以维护

**解决方案**:
- 修改 openActionMenu 函数签名,添加 MouseEvent 参数
- 使用 e.currentTarget 直接从事件对象获取触发元素
- 移除不必要的 actionButtonRefs Map 和 setActionButtonRef 函数
- 统一菜单宽度为 200px(与 AccountsView 一致)
- 完全复制 AccountsView 的定位逻辑,确保两者行为一致

**技术要点**:
- 移动端:菜单居中对齐按钮,优先显示在按钮下方
- 桌面端:使用鼠标位置定位,添加边界检测
- 简化代码,移除不必要的防御性检查
- 两个组件的菜单定位逻辑完全一致,便于维护
2026-01-15 10:21:35 +08:00
IanShaw027
20c71acb3b fix(mobile): 优化移动端表格、操作栏和弹窗显示
**问题描述**:
- 表格在移动端显示列过多,需要横向滚动,内容被截断
- 顶部操作栏按钮拥挤,占用过多空间
- 弹窗表单在小屏幕上布局不合理
- "更多"操作菜单定位错误,位置过高或超出屏幕
- 滚动页面时菜单不会自动关闭,与卡片分离

**解决方案**:

1. **DataTable 组件 - 移动端卡片视图**
   - 在 < 768px 时自动切换到卡片布局
   - 每个表格行渲染为独立卡片,所有字段清晰可见
   - 操作按钮在卡片底部,触摸目标足够大
   - 支持深色模式,包含加载和空状态
   - 自动应用于所有使用 DataTable 的管理页面

2. **UsersView 顶部操作栏优化**
   - 移动端:搜索框全宽 + 次要按钮显示为图标 + 创建按钮突出
   - 桌面端:保持原有布局(图标 + 文字)
   - 使用响应式 Tailwind classes

3. **UserCreateModal 弹窗优化**
   - 余额/并发数字段:移动端单列,桌面端双列
   - 弹窗边距:移动端 8px,桌面端 16px

4. **操作菜单定位修复**
   - UsersView: 移动端菜单居中对齐按钮,智能定位
   - AccountsView: 移动端菜单优先显示在按钮下方
   - 所有情况下确保菜单不超出屏幕边界
   - 添加滚动监听,滚动时自动关闭菜单

**影响范围**:
- 所有使用 DataTable 的管理页面(8 个页面)自动获得移动端卡片视图
- 用户管理和账号管理页面的操作菜单定位优化
- 创建用户弹窗的响应式布局优化

**技术要点**:
- 使用 Tailwind 响应式断点(md:, sm:)
- 触摸目标 ≥ 44px
- 完整支持深色模式
- 向后兼容,桌面端保持原有布局
2026-01-15 10:08:14 +08:00
yangjianbo
52ad7c6e9c Merge branch 'main' into dev 2026-01-15 09:30:29 +08:00
longgexx
5aaaffe4d1 fix(dashboard): 修复仪表盘今日统计使用UTC时区的问题
将仪表盘统计中的"今日"时间范围从UTC时区改为服务器配置时区,
使其与使用记录页面保持一致。

修改内容:
- GetDashboardStats: 使用 timezone.Now() 和 timezone.Today()
- GetDashboardStatsWithRange: 同上

影响的统计项:
- 今日请求 (TodayRequests)
- 今日 Token (TodayTokens)
- 今日费用 (TodayCost/TodayActualCost)
- 今日新用户 (TodayNewUsers)
- 今日活跃用户 (ActiveUsers)
2026-01-15 09:12:16 +08:00
IanShaw027
5354ba3662 fix(ops): 修复错误列表用户显示并区分上游错误和请求错误
- 修复错误列表中用户列显示 \n 的问题
- 上游错误显示账号(account),请求错误显示用户(user)
- 错误详情模态框同步调整显示逻辑
- 添加 accountId 国际化翻译
2026-01-15 00:11:44 +08:00
IanShaw027
2daf13c4c8 style(backend): 修复 ops_service.go 代码格式 2026-01-15 00:04:37 +08:00
IanShaw027
16a90f3d3a fix(types): 在 OpsErrorLog 类型中添加 user_email 字段
- 修复 TypeScript 编译错误
- 添加 user_email 字段到 OpsErrorLog 接口
2026-01-15 00:03:34 +08:00
IanShaw027
8a0ff15242 feat(i18n): 添加用户相关的国际化翻译
- 在中英文 i18n 文件中添加 errorLog.user 和 errorLog.userId
- 在中英文 i18n 文件中添加 errorDetail.user
- 支持错误日志和详情中的用户信息显示
2026-01-15 00:02:32 +08:00
IanShaw027
8c993dfd35 refactor(frontend): 将账号显示替换为用户显示
- 在错误日志表格中将账号列替换为用户列
- 在错误详情模态框中将账号信息替换为用户信息
- 显示用户邮箱而不是账号名称
- 上游错误的账号信息保留在上游错误上下文中
2026-01-14 23:59:26 +08:00
IanShaw027
2a6fb1e456 feat(ops): 添加用户信息显示和搜索功能
- 在错误日志列表和详情中显示用户邮箱
- 在 GetErrorLogByID 中关联 users 表获取用户邮箱
- 在 OpsErrorLogFilter 中添加 UserQuery 字段
- 在 buildOpsErrorLogsWhere 中添加用户邮箱搜索条件
- 在 GetErrorLogs handler 中支持 user_query 参数
2026-01-14 23:56:45 +08:00
IanShaw027
9e6cd36af4 feat(ops): 添加上游响应体字段到错误事件
- 在 OpsUpstreamErrorEvent 中添加 UpstreamResponseBody 字段
- 用于存储上游服务返回的响应内容
- 区分客户端响应和上游响应
2026-01-14 23:52:36 +08:00
IanShaw027
f25f992a30 fix(ops): 错误详情中显示账号和分组名称
- 在 GetErrorLogByID 查询中添加 LEFT JOIN 关联查询
- 关联 accounts 和 groups 表获取名称
- 填充 AccountName 和 GroupName 字段
2026-01-14 23:52:00 +08:00
IanShaw027
841d7ef2f2 fix(lint): 修复 golangci-lint 检查问题
- 格式化代码(gofmt)
- 修复空指针检查(staticcheck)
- 删除未使用的函数(unused)
2026-01-14 23:49:27 +08:00
IanShaw027
a7a49be850 refactor(ops): 使用TTFT替代Duration作为健康分数指标
- 业务健康分数:错误率 50% + TTFT 50%
- TTFT 阈值:1s → 100分,3s → 0分
- TTFT 对 AI 服务的用户体验更有意义
- 更新所有相关测试用例期望值
2026-01-14 23:47:43 +08:00
IanShaw027
d5eab7da3b refactor(ops): 优化健康分数计算逻辑和阈值
- 移除 SLA 组件(与错误率重复)
- 恢复延迟组件,阈值调整为 1s-2s
- 错误率阈值调整为 1%-10%(更宽松)
- 业务健康分数:错误率 50% + 延迟 50%
- 更新所有相关测试用例期望值
2026-01-14 23:43:12 +08:00
IanShaw027
9b10241561 test(ops): 修复健康分数测试用例期望值
- 更新 TestComputeBusinessHealth 中 SLA 95% 边界测试的期望值
- 更新 TestComputeDashboardHealthScore 中中等健康度测试的期望值
- 适配移除延迟组件后的新健康分数计算逻辑
2026-01-14 23:39:09 +08:00
IanShaw027
76448ab555 refactor(frontend): 优化ops看板骨架屏组件
- 添加 fullscreen 属性支持,适配全屏模式
- 优化骨架屏布局,更好地匹配实际看板结构
- 改进加载动画效果,提升用户体验
2026-01-14 23:26:34 +08:00
IanShaw027
9584af5cb4 fix(ops): 优化错误日志查询和详情展示
- 新增 GetErrorLogByID 接口用于获取单个错误日志详情
- 优化 GetErrorLogs 过滤逻辑,简化参数处理
- 简化前端错误详情模态框代码,提升可维护性
- 更新相关 API 接口和 i18n 翻译
2026-01-14 23:16:01 +08:00
longgexx
6fabddcb0b fix(test): 更新集成测试以匹配新的筛选参数签名
更新 usage_log_repo_integration_test.go 中的测试用例,
   使其与 GetUsageTrendWithFilters 和 GetModelStatsWithFilters
   方法的新签名保持一致。
2026-01-14 22:29:14 +08:00
longgexx
5efeabb0c6 fix(admin): 修复使用记录页面趋势图筛选联动和日期选择问题
修复两个问题:
   1. Token使用趋势图和模型分布图未响应筛选条件
   2. 上午时段选择今天刷新后日期回退到前一天

   前端修改:
   - 更新 dashboard API 类型定义,添加 model、account_id、group_id、stream 参数支持
   - 修改 UsageView 趋势图加载逻辑,传递所有筛选参数到后端
   - 修复日期格式化函数,使用本地时区避免 UTC 转换导致的日期偏移

   后端修改:
   - Handler 层:接收并解析所有筛选参数(model、account_id、group_id、stream)
   - Service 层:传递完整的筛选参数到 Repository 层
   - Repository 层:SQL 查询动态添加所有过滤条件
   - 更新接口定义和测试 mock 以保持一致性

   影响范围:
   - /admin/dashboard/trend 端点现支持完整筛选
   - /admin/dashboard/models 端点现支持完整筛选
   - 用户在后台使用记录页面选择任意筛选条件时,趋势图和模型分布图会实时响应
   - 日期选择器在任何时区下都能正确保持今天的选择
2026-01-14 22:02:56 +08:00
longgexx
806f402bba fix(admin): 修复使用记录页面趋势图筛选联动和日期选择问题
修复两个问题:
   1. Token使用趋势图和模型分布图未响应筛选条件
   2. 上午时段选择今天刷新后日期回退到前一天

   前端修改:
   - 更新 dashboard API 类型定义,添加 model、account_id、group_id、stream 参数支持
   - 修改 UsageView 趋势图加载逻辑,传递所有筛选参数到后端
   - 修复日期格式化函数,使用本地时区避免 UTC 转换导致的日期偏移

   后端修改:
   - Handler 层:接收并解析所有筛选参数(model、account_id、group_id、stream)
   - Service 层:传递完整的筛选参数到 Repository 层
   - Repository 层:SQL 查询动态添加所有过滤条件
   - 更新接口定义和所有调用点以保持一致性

   影响范围:
   - /admin/dashboard/trend 端点现支持完整筛选
   - /admin/dashboard/models 端点现支持完整筛选
   - 用户在后台使用记录页面选择任意筛选条件时,趋势图和模型分布图会实时响应
   - 日期选择器在任何时区下都能正确保持今天的选择
2026-01-14 21:52:50 +08:00
IanShaw027
5432087d96 refactor(frontend): 优化ops错误详情模态框代码格式和功能
- 重构OpsErrorDetailModal.vue代码格式,提升可读性
- 添加上游错误tab显示功能
- 完善i18n翻译(upstream_http)
- 优化其他ops组件代码格式
2026-01-14 20:49:18 +08:00
LLLLLLiulei
02cb14c7b8 test: fix proxy repo stub 2026-01-14 19:56:19 +08:00
LLLLLLiulei
9bdb45be7c feat: enhance proxy management 2026-01-14 19:45:29 +08:00
IanShaw027
514c0562e0 refactor(frontend): 清理OpsDashboardHeader中的i18n翻译
将技术术语的i18n翻译键替换为硬编码文本:
- ms (P99) - 毫秒和百分位数标识
- TTFT - Time To First Token缩写

这些是通用技术术语,不需要国际化。
2026-01-14 19:02:02 +08:00
IanShaw027
371275ec34 refactor(frontend): 清理ops组件中未使用的i18n翻译
- 移除i18n文件中未使用的翻译键(cpu, redis, qps, ttft等)
- 将技术术语改为硬编码(QPS, CPU, TPS等不需要翻译)
- 简化OpsDashboardHeader、OpsErrorDetailModal等组件的i18n调用
2026-01-14 17:04:30 +08:00
墨颜
ec24a3c361 Merge branch 'Wei-Shaw:main' into main 2026-01-14 16:42:30 +08:00
墨颜
d89e797bfc Merge branch 'main' of https://github.com/whoismonay/sub2api 2026-01-14 16:30:32 +08:00
IanShaw027
55e469c7fe fix(ops): 优化错误日志过滤和查询逻辑
后端改动:
- 添加 resolved 参数默认值处理(向后兼容,默认显示未解决错误)
- 新增 status_codes_other 查询参数支持
- 移除 service 层的高级设置过滤逻辑,简化错误日志查询流程

前端改动:
- 完善错误日志相关组件的国际化支持
- 优化 Ops 监控面板和设置对话框的用户体验
2026-01-14 16:26:33 +08:00
墨颜
fb99ceacc7 feat(计费): 支持账号计费倍率快照与统计展示
- 新增 accounts.rate_multiplier(默认 1.0,允许 0)
- 使用 usage_logs.account_rate_multiplier 记录倍率快照,避免历史回算
- 统计/导出/管理端展示账号口径费用(total_cost * account_rate_multiplier)
2026-01-14 16:12:08 +08:00
yangjianbo
daf10907e4 fix(认证): 修复 OAuth token 缓存失效与 401 处理
新增 token 缓存失效接口并在刷新后清理
401 限流支持自定义规则与可配置冷却时间
补齐缓存失效与 401 处理测试
测试: make test
2026-01-14 15:55:44 +08:00
Wesley Liddick
b3b2868f55 Merge pull request #278 from IanShaw027/fix/openai-account-schedulability-check
fix(网关): 修复账号选择中的调度器快照延迟问题
2026-01-14 15:52:35 +08:00
ianshaw
25b00abca1 fix(网关): 修复账号选择中的调度器快照延迟问题
## 问题描述
调度器快照更新存在0.5-1秒的延迟(Outbox轮询间隔),导致在账号被限流或过载后的短时间窗口内,
可能仍会被选中,造成请求失败。

## 根本原因
账号选择逻辑依赖调度器快照(listSchedulableAccounts),但快照更新有延迟:
- Outbox轮询: 每1秒检查一次变更事件
- 全量重建: 每300秒重建一次
- 时间窗口: 账号状态变更后0.5-1秒内,快照可能未更新

## 解决方案
在账号选择循环中添加IsSchedulable()实时检查,作为第二道防线:
1. 第一道防线: 调度器快照过滤(可能有延迟)
2. 第二道防线: IsSchedulable()实时检查(本次修复)

IsSchedulable()会检查:
- RateLimitResetAt: 限流重置时间
- OverloadUntil: 过载持续时间
- TempUnschedulableUntil: 临时不可调度时间
- Status: 账号状态
- Schedulable: 可调度标志

## 修改范围
### OpenAI Gateway Service
- SelectAccountForModelWithExclusions: 添加IsSchedulable()检查
- SelectAccountWithLoadAwareness: 添加IsSchedulable()检查

### Gateway Service (Claude/Gemini/Antigravity)
- 负载感知选择候选账号筛选: 添加IsSchedulable()检查
- selectAccountForModelWithPlatform: 添加IsSchedulable()检查
- selectAccountWithMixedScheduling: 添加IsSchedulable()检查

### 测试用例
- OpenAI: 添加2个测试用例验证限流账号过滤
- Gateway: 添加2个测试用例验证限流和过载账号过滤

### 其他修复
- ops_repo_preagg.go: 修复platform为NULL时的聚合问题

## 测试结果
所有单元测试通过 
2026-01-13 22:49:26 -08:00
IanShaw027
8d0767352b fix(ops): 修复ops handler逻辑 2026-01-14 14:30:41 +08:00
IanShaw027
918a253851 feat(frontend): 完善ops监控面板和组件功能 2026-01-14 14:30:18 +08:00
IanShaw027
63711067e6 refactor(ops): 完善gateway服务ops集成 2026-01-14 14:30:00 +08:00
IanShaw027
7158b38897 refactor(ops): 优化ops repository数据访问层 2026-01-14 14:29:39 +08:00
IanShaw027
7f317b9093 feat(ops): 增强ops核心服务功能和重试机制 2026-01-14 14:29:19 +08:00
IanShaw027
7c4309ea24 feat(ops): 添加ops handler和路由配置 2026-01-14 14:29:01 +08:00
IanShaw027
5013290486 feat(frontend): 优化ops监控UI组件 2026-01-14 12:41:24 +08:00
IanShaw027
8cf3e9a620 feat(frontend): 更新ops API接口和国际化文案 2026-01-14 12:41:05 +08:00
IanShaw027
060699c3b8 refactor(ops): 更新gateway服务集成ops功能 2026-01-14 12:40:49 +08:00
IanShaw027
2ca6c631ac refactor(ops): 重构ops handler和repository层 2026-01-14 12:40:34 +08:00
IanShaw027
967e25878f refactor(ops): 重构ops核心服务层代码 2026-01-14 12:40:12 +08:00
IanShaw027
182683814b refactor(ops): 移除duration相关告警指标,简化监控配置
主要改动:
- 移除 p95_latency_ms 和 p99_latency_ms 告警指标类型
- 移除配置中的 latency_p99_ms_max 阈值设置
- 简化健康分数计算(移除latency权重,重新归一化SLA和错误率)
- 移除duration相关的诊断规则和阈值检查
- 统一术语:延迟 → 请求时长
- 保留duration数据展示,但不再用于告警判断
- 聚焦TTFT作为主要的响应速度告警指标

影响范围:
- Backend: handler, service, models, tests
- Frontend: API types, i18n, components
2026-01-14 10:52:56 +08:00
shaw
99cbfa1567 fix(admin): 修复退款金额精度问题
- 显示完整余额精度,避免四舍五入导致的退款失败
- 添加"全部"按钮,一键填入完整余额
- 移除最小金额限制,支持任意正数金额
2026-01-14 10:22:31 +08:00
Wesley Liddick
3f8c8d70ad Merge pull request #274 from mt21625457/main
fix(openai): OAuth 请求强制 store=false
2026-01-14 09:53:09 +08:00
yangjianbo
9c567fad92 fix(网关): 优化 OAuth 请求中 store 参数的处理逻辑 2026-01-14 09:46:10 +08:00
IanShaw027
33f58d583d fix(ops): 修复告警状态验证和错误处理逻辑
- 增强告警事件状态验证,添加合法状态值检查
- 移除重试逻辑中的遗留字段赋值
- 修正仓库不可用时的错误类型
- 格式化测试文件代码
2026-01-14 09:39:18 +08:00
yangjianbo
0abb3a6843 Merge branch 'main' of https://github.com/mt21625457/aicodex2api 2026-01-14 09:22:37 +08:00
yangjianbo
3663951d11 fix(网关): OAuth 请求强制 store=false
避免上游 Store 必须为 false 的错误

仅在缺失或 true 时写回 store

测试: go test ./internal/service -run TestApplyCodexOAuthTransform

测试: make test-backend(golangci-lint 已单独执行)
2026-01-14 09:17:58 +08:00
IanShaw027
1e169685f4 feat(i18n): 添加ops新功能的国际化文案
- 新增告警静默相关的中英文翻译
- 补充错误分类和重试状态的文案
- 完善ops管理界面的提示信息
2026-01-14 09:04:09 +08:00
IanShaw027
f38a3e7585 feat(ui): 优化ops监控面板和组件功能
- 增强告警事件卡片的交互和静默功能
- 完善错误详情弹窗的展示和操作
- 优化错误日志表格的筛选和排序
- 新增重试和解决状态的UI支持
2026-01-14 09:03:59 +08:00
IanShaw027
b8da5d45ce feat(api): 扩展前端ops API接口
- 新增告警静默相关API调用
- 增强错误日志查询和过滤接口
- 添加重试和解决状态管理接口
2026-01-14 09:03:45 +08:00
IanShaw027
659df6e220 feat(handler): 新增ops管理接口和路由
- 添加告警静默管理接口
- 扩展错误日志查询和操作接口
- 新增重试和解决状态相关端点
- 完善错误日志记录功能
2026-01-14 09:03:35 +08:00
IanShaw027
d601768016 feat(service): 增强ops业务逻辑和告警功能
- 实现告警静默功能的业务逻辑
- 优化错误分类和重试机制
- 扩展告警评估和通知功能
- 完善错误解决和重试结果处理
2026-01-14 09:03:16 +08:00
IanShaw027
16ddc6a83b feat(repository): 扩展ops数据访问层功能
- 新增告警静默相关数据库操作
- 增强错误日志查询和统计功能
- 优化重试结果和解决状态的存储
2026-01-14 09:03:01 +08:00
IanShaw027
340dc9cadb feat(db): 添加ops告警静默和错误分类优化迁移
- 添加ops告警静默功能的数据库结构
- 优化错误分类和重试结果字段标准化
2026-01-14 09:02:45 +08:00
Wesley Liddick
55fced3942 Merge pull request #269 from mt21625457/main
fix: 修复opencode 适配openai 套餐的错误,通过sub2api完美转发 opencode
2026-01-13 17:33:07 +08:00
yangjianbo
7bbf49fd65 为类型断言补充 ok 校验并添加中文说明,避免 errcheck 报错(backend/internal/service/
openai_codex_transform_test.go:36、backend/internal/service/
    openai_codex_transform_test.go:89、backend/internal/service/
    openai_codex_transform_test.go:104)。
2026-01-13 17:22:57 +08:00
yangjianbo
eea6c2d02c fix(网关): 补齐Codex指令回退与输入过滤 2026-01-13 17:02:31 +08:00
yangjianbo
70eaa450db fix(网关): 修复工具续链校验与存储策略
完善 function_call_output 续链校验与引用匹配
续链场景强制 store=true,过滤 input 时避免副作用
补充续链判断与过滤相关单元测试

测试: go test ./...
2026-01-13 16:47:35 +08:00
Wesley Liddick
55796a118d Merge pull request #264 from IanShaw027/fix/openai-opencode-compatibility
fix(openai): 增强 OpenCode 兼容性和模型规范化
2026-01-13 16:01:37 +08:00
song
9a22d1a690 refactor: 提取 getOrCreateGeminiParts 减少重复代码
将两个 merge 函数中重复的 Gemini 响应结构访问逻辑提取为公共函数。
2026-01-13 13:25:55 +08:00
song
c9d21d53e6 fix: 修复 Antigravity 非流式响应文本丢失问题
Gemini 流式响应是增量的,需要累积所有 chunk 的文本内容。
原代码只保留最后一个有 parts 的 chunk,导致实际文本被空
text + thoughtSignature 的最终 chunk 覆盖。

添加 collectedTextParts 收集所有文本片段,返回前合并。
2026-01-13 13:04:03 +08:00
song
e1015c2759 fix: 修复 Antigravity 图片生成响应丢失问题
流式转非流式时,图片数据在中间 chunk 返回,最后一个 chunk 只有
finishReason,导致只保留最后 chunk 时图片丢失。

添加 collectedImageParts 收集所有图片 parts,并在返回前合并。
2026-01-13 12:58:05 +08:00
ianshaw
d7fa47d732 refactor(openai): 移除不必要的 seedOpenAISessionHeaders 函数 2026-01-12 20:38:46 -08:00
ianshaw
3d6e01a58f fix(openai): 增强 OpenCode 兼容性和模型规范化
## 主要改动

1. **模型规范化扩展到所有账号**
   - 将 Codex 模型规范化(如 gpt-5-nano → gpt-5.1)应用到所有 OpenAI 账号类型
   - 不再仅限于 OAuth 非 CLI 请求
   - 解决 Codex CLI 使用 ChatGPT 账号时的模型兼容性问题

2. **reasoning.effort 参数规范化**
   - 自动将 `minimal` 转换为 `none`
   - 解决 gpt-5.1 模型不支持 `minimal` 值的问题

3. **Session/Conversation ID fallback 机制**
   - 从请求体多个字段提取 session_id/conversation_id
   - 优先级:prompt_cache_key → session_id → conversation_id → previous_response_id
   - 支持 Codex CLI 的会话保持

4. **Tool Call ID fallback**
   - 当 call_id 为空时使用 id 字段作为 fallback
   - 确保 tool call 输出能正确匹配
   - 保留 item_reference 类型的 items

5. **Header 优化**
   - 添加 conversation_id 到允许的 headers
   - 移除删除 session headers 的逻辑

## 相关 Issue
- 参考 OpenCode issue #3118 关于 item_reference 的讨论
2026-01-12 20:18:53 -08:00
IanShaw027
f9713e8733 fix(codex): 添加codex CLI instructions fallback机制
## 问题
- 使用OpenAI API key时,opencode客户端可能因instructions不兼容而报错
- 依赖外部GitHub获取instructions,网络故障时会失败

## 解决方案
1. 将codex CLI标准instructions嵌入到项目中
2. 实现自动fallback机制:
   - 优先使用opencode GitHub的instructions
   - 失败时自动fallback到本地codex CLI instructions
3. 添加辅助函数用于错误检测和手动替换

## 改动
- 新增: internal/service/prompts/codex_cli_instructions.md
  - 从codex项目复制的标准instructions
  - 使用go:embed嵌入到二进制文件

- 修改: internal/service/openai_codex_transform.go
  - 添加embed支持
  - 增强getOpenCodeCodexHeader()的fallback逻辑
  - 新增GetCodexCLIInstructions()公开函数
  - 新增ReplaceWithCodexInstructions()用于手动替换
  - 新增IsInstructionError()用于错误检测

## 优势
- 零停机:GitHub不可用时仍能正常工作
- 离线可用:不依赖外部网络
- 兼容性:使用标准codex CLI instructions
- 部署简单:instructions嵌入到二进制文件
2026-01-13 11:14:32 +08:00
yangjianbo
0e44829720 Merge branch 'main' into dev 2026-01-13 10:29:12 +08:00
shaw
93db889a10 fix: Gemini OpenCode 教程 baseURL 改为 v1beta 2026-01-13 09:52:37 +08:00
Wesley Liddick
0df7385c4e Merge pull request #226 from xilu0/main
feat(gateway): 优化 Antigravity/Gemini 思考块处理 此提交解决了思考块 (thinking blocks) 在转发过程中的兼容性问题
2026-01-13 09:39:43 +08:00
Wesley Liddick
1a3fa6411c Merge pull request #260 from IanShaw027/fix/sync-openai-gpt5-models
fix: 同步 OpenAI GPT-5 模型列表并完善参数处理
2026-01-13 09:31:00 +08:00
Wesley Liddick
64614756d1 Merge pull request #259 from cyhhao/main
fix: adjust OpenCode OpenAI example store placement
2026-01-13 09:30:26 +08:00
Wesley Liddick
bb1fd54d4d Merge pull request #257 from Edric-Li/feat/ops-fullscreen-scrollbar
feat(ops): 添加运维监控全屏模式 & 优化滚动条
2026-01-13 09:29:25 +08:00
ianshaw
d85288a6c0 Revert "fix(gateway): 修复 base_url 包含 /chat/completions 时路径拼接错误"
This reverts commit 7fdc25df3c.
2026-01-12 13:29:04 -08:00
ianshaw
3402acb606 feat(gateway): 对所有请求(包括 Codex CLI)应用模型映射
- 移除 Codex CLI 的模型映射跳过逻辑
- 添加详细的模型映射日志,包含账号名称和请求类型
- 确保所有 OpenAI 请求都能正确应用账号配置的模型映射
2026-01-12 13:23:05 -08:00
ianshaw
7fdc25df3c fix(gateway): 修复 base_url 包含 /chat/completions 时路径拼接错误
问题:
- 当账号的 base_url 配置为 https://example.com/v1/chat/completions 时
- 代码直接追加 /responses,导致路径变成 /v1/chat/completions/responses
- 上游返回 404 错误

修复:
- 在追加 /responses 前,先移除 base_url 中的 /chat/completions 后缀
- 确保最终路径为 https://example.com/v1/responses

影响范围:
- OpenAI API Key 账号的测试接口
- OpenAI API Key 账号的实际网关请求

Related-to: #231
2026-01-12 11:39:45 -08:00
ianshaw
ea699cbdc2 docs(frontend): 完善 OpenCode 配置说明
更新 API 密钥页面 OpenCode 配置提示信息:
- 补充支持 opencode.jsonc 后缀名
- 说明可使用默认 provider(openai/anthropic/google)或自定义 provider_id
- 说明 API Key 支持直接配置或通过 /connect 命令配置
- 保留"示例仅供参考,模型与选项可按需调整"的提示

配置文件路径:~/.config/opencode/opencode.json(或 opencode.jsonc)
2026-01-12 11:17:47 -08:00
ianshaw
fe6a3f4267 fix(gateway): 完善 max_output_tokens 参数处理逻辑
根据不同平台和账号类型处理 max_output_tokens 参数:
- OpenAI OAuth (Responses API): 保留 max_output_tokens(支持)
- OpenAI API Key: 删除 max_output_tokens(不支持)
- Anthropic (Claude): 转换 max_output_tokens 为 max_tokens
- Gemini: 删除 max_output_tokens(由 Gemini 专用转换处理)
- 其他平台: 删除(安全起见)

同时处理 max_completion_tokens 参数,仅在 OpenAI OAuth 时保留。

修复客户端(如 OpenCode)发送不支持参数导致上游返回 400 错误的问题。

Related-to: #231
2026-01-12 11:08:28 -08:00
ianshaw
fe8198c8cd fix(frontend): 同步 OpenAI GPT-5 系列模型列表
修复编辑账号页面 GPT-5 模型只显示 3 个的问题:
- 原来只有: gpt-5, gpt-5-mini, gpt-5-nano
- 现在添加完整的 22 个模型,包括:
  * GPT-5 系列: gpt-5, gpt-5-codex, gpt-5-chat, gpt-5-pro, gpt-5-mini, gpt-5-nano 及各时间戳版本
  * GPT-5.1 系列: gpt-5.1, gpt-5.1-codex, gpt-5.1-codex-max, gpt-5.1-codex-mini 及各版本
  * GPT-5.2 系列: gpt-5.2, gpt-5.2-codex, gpt-5.2-pro 及各版本
- 更新快捷预设按钮,新增 GPT-5.1, GPT-5.2, GPT-5.1 Codex 选项

与后端定价文件 (model_prices_and_context_window.json) 保持一致。

Fixes issue introduced in fb86002 (feat: 添加模型白名单选择器组件)
Related-to: fb86002ef9
2026-01-12 10:14:50 -08:00
cyhhao
675e61385f Merge branch 'main' of github.com:Wei-Shaw/sub2api 2026-01-12 22:36:14 +08:00
cyhhao
67acac1082 fix: adjust OpenCode OpenAI example store placement 2026-01-12 22:31:43 +08:00
Edric Li
d02e1db018 style: 优化滚动条自动隐藏效果
- 默认隐藏滚动条,悬停时显示
- 支持 Webkit (Chrome/Safari/Edge) 和 Firefox
- 滚动条样式与暗色主题适配
2026-01-12 22:10:59 +08:00
Edric Li
0da515071b feat(ops): 添加运维监控全屏模式
- 支持通过 URL 参数 ?fullscreen=1 进入全屏模式
- 全屏模式下隐藏非必要 UI 元素(选择器、按钮、提示等)
- 增大健康评分圆环和字体以提升可读性
- 支持 ESC 键退出全屏
- 添加全屏按钮的 i18n 翻译
2026-01-12 22:10:59 +08:00
xiluo
524d80ae1c feat(gateway): 优化 Antigravity/Gemini 思考块处理
此提交解决了思考块 (thinking blocks) 在转发过程中的兼容性问题。

主要变更:

1. **思考块优化 (Thinking Blocks)**:
   - 在 AntigravityGatewayService 中增加了 sanitizeThinkingBlocks 处理,强制移除思考块中不支持的 cache_control 字段(避免 Anthropic/Vertex AI 报错)
   - 实现历史思考块展平 (Flattening):将非最后一条消息中的思考块转换为普通文本块,以绕过上游对历史思考块签名的严格校验
   - 增加 cleanCacheControlFromGeminiJSON 作为最后一道防线,确保转换后的 Gemini 请求中不残留非法的 cache_control

2. **GatewayService 缓存控制优化**:
   - 更新缓存控制逻辑,跳过 thinking 块(thinking 块不支持 cache_control 字段)
   - 增加 removeCacheControlFromThinkingBlocks 函数强制清理

关联 Issue: #225
2026-01-12 13:36:59 +00:00
song
f0ece82111 feat: 在 dashboard 右上角添加文档链接 2026-01-12 17:01:57 +08:00
yangjianbo
9618cb5643 Merge branch 'main' into test 2026-01-12 15:15:03 +08:00
yangjianbo
9c02ab789d fix(rate_limiter): 更新速率限制逻辑,支持返回修复状态 2026-01-12 14:42:58 +08:00
song
11bfc807d7 merge upstream/main 2026-01-09 21:37:27 +08:00
song
c2a6ca8d3a chore: 提升 SSE 单行上限到 40MB 2026-01-09 20:57:06 +08:00
song
7b1cf2c495 chore: 调整 SSE 单行上限到 25MB 2026-01-09 20:47:13 +08:00
song
da1f3d61be feat: antigravity 配额域限流 2026-01-09 17:35:02 +08:00
song
dc3cd62125 Merge remote-tracking branch 'upstream/main' 2026-01-08 22:52:18 +08:00
song
bc404d4fc1 merge: 合并 upstream/main 使用上游版本解决冲突 2026-01-08 18:01:29 +08:00
song
a4a0c0e2cc feat(antigravity): 增强请求参数和注入 Antigravity 身份 system prompt 2026-01-08 13:07:20 +08:00
song
c7abfe67b5 Merge remote-tracking branch 'upstream/main' 2026-01-08 12:17:56 +08:00
song
4e3476a669 fix: 添加 gemini-3-flash 前缀映射支持 gemini-3-flash-preview 2026-01-06 15:09:21 +08:00
378 changed files with 49817 additions and 4182 deletions

3
.gitignore vendored
View File

@@ -83,6 +83,8 @@ temp/
*.log
*.bak
.cache/
.dev/
.serena/
# ===================
# 构建产物
@@ -127,3 +129,4 @@ deploy/docker-compose.override.yml
.gocache/
vite.config.js
docs/*
.serena/

164
PR_DESCRIPTION.md Normal file
View File

@@ -0,0 +1,164 @@
## 概述
全面增强运维监控系统Ops的错误日志管理和告警静默功能优化前端 UI 组件代码质量和用户体验。本次更新重构了核心服务层和数据访问层,提升系统可维护性和运维效率。
## 主要改动
### 1. 错误日志查询优化
**功能特性:**
- 新增 GetErrorLogByID 接口,支持按 ID 精确查询错误详情
- 优化错误日志过滤逻辑,支持多维度筛选(平台、阶段、来源、所有者等)
- 改进查询参数处理,简化代码结构
- 增强错误分类和标准化处理
- 支持错误解决状态追踪resolved 字段)
**技术实现:**
- `ops_handler.go` - 新增单条错误日志查询接口
- `ops_repo.go` - 优化数据查询和过滤条件构建
- `ops_models.go` - 扩展错误日志数据模型
- 前端 API 接口同步更新
### 2. 告警静默功能
**功能特性:**
- 支持按规则、平台、分组、区域等维度静默告警
- 可设置静默时长和原因说明
- 静默记录可追溯,记录创建人和创建时间
- 自动过期机制,避免永久静默
**技术实现:**
- `037_ops_alert_silences.sql` - 新增告警静默表
- `ops_alerts.go` - 告警静默逻辑实现
- `ops_alerts_handler.go` - 告警静默 API 接口
- `OpsAlertEventsCard.vue` - 前端告警静默操作界面
**数据库结构:**
| 字段 | 类型 | 说明 |
|------|------|------|
| rule_id | BIGINT | 告警规则 ID |
| platform | VARCHAR(64) | 平台标识 |
| group_id | BIGINT | 分组 ID可选 |
| region | VARCHAR(64) | 区域(可选) |
| until | TIMESTAMPTZ | 静默截止时间 |
| reason | TEXT | 静默原因 |
| created_by | BIGINT | 创建人 ID |
### 3. 错误分类标准化
**功能特性:**
- 统一错误阶段分类request|auth|routing|upstream|network|internal
- 规范错误归属分类client|provider|platform
- 标准化错误来源分类client_request|upstream_http|gateway
- 自动迁移历史数据到新分类体系
**技术实现:**
- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - 分类标准化迁移
- 自动映射历史遗留分类到新标准
- 自动解决已恢复的上游错误(客户端状态码 < 400
### 4. Gateway 服务集成
**功能特性:**
- 完善各 Gateway 服务的 Ops 集成
- 统一错误日志记录接口
- 增强上游错误追踪能力
**涉及服务:**
- `antigravity_gateway_service.go` - Antigravity 网关集成
- `gateway_service.go` - 通用网关集成
- `gemini_messages_compat_service.go` - Gemini 兼容层集成
- `openai_gateway_service.go` - OpenAI 网关集成
### 5. 前端 UI 优化
**代码重构:**
- 大幅简化错误详情模态框代码(从 828 行优化到 450 行)
- 优化错误日志表格组件,提升可读性
- 清理未使用的 i18n 翻译,减少冗余
- 统一组件代码风格和格式
- 优化骨架屏组件,更好匹配实际看板布局
**布局改进:**
- 修复模态框内容溢出和滚动问题
- 优化表格布局,使用 flex 布局确保正确显示
- 改进看板头部布局和交互
- 提升响应式体验
- 骨架屏支持全屏模式适配
**交互优化:**
- 优化告警事件卡片功能和展示
- 改进错误详情展示逻辑
- 增强请求详情模态框
- 完善运行时设置卡片
- 改进加载动画效果
### 6. 国际化完善
**文案补充:**
- 补充错误日志相关的英文翻译
- 添加告警静默功能的中英文文案
- 完善提示文本和错误信息
- 统一术语翻译标准
## 文件变更
**后端26 个文件):**
- `backend/internal/handler/admin/ops_alerts_handler.go` - 告警接口增强
- `backend/internal/handler/admin/ops_handler.go` - 错误日志接口优化
- `backend/internal/handler/ops_error_logger.go` - 错误记录器增强
- `backend/internal/repository/ops_repo.go` - 数据访问层重构
- `backend/internal/repository/ops_repo_alerts.go` - 告警数据访问增强
- `backend/internal/service/ops_*.go` - 核心服务层重构10 个文件)
- `backend/internal/service/*_gateway_service.go` - Gateway 集成4 个文件)
- `backend/internal/server/routes/admin.go` - 路由配置更新
- `backend/migrations/*.sql` - 数据库迁移2 个文件)
- 测试文件更新5 个文件)
**前端13 个文件):**
- `frontend/src/views/admin/ops/OpsDashboard.vue` - 看板主页优化
- `frontend/src/views/admin/ops/components/*.vue` - 组件重构10 个文件)
- `frontend/src/api/admin/ops.ts` - API 接口扩展
- `frontend/src/i18n/locales/*.ts` - 国际化文本2 个文件)
## 代码统计
- 44 个文件修改
- 3733 行新增
- 995 行删除
- 净增加 2738 行
## 核心改进
**可维护性提升:**
- 重构核心服务层,职责更清晰
- 简化前端组件代码,降低复杂度
- 统一代码风格和命名规范
- 清理冗余代码和未使用的翻译
- 标准化错误分类体系
**功能完善:**
- 告警静默功能,减少告警噪音
- 错误日志查询优化,提升运维效率
- Gateway 服务集成完善,统一监控能力
- 错误解决状态追踪,便于问题管理
**用户体验优化:**
- 修复多个 UI 布局问题
- 优化交互流程
- 完善国际化支持
- 提升响应式体验
- 改进加载状态展示
## 测试验证
- ✅ 错误日志查询和过滤功能
- ✅ 告警静默创建和自动过期
- ✅ 错误分类标准化迁移
- ✅ Gateway 服务错误日志记录
- ✅ 前端组件布局和交互
- ✅ 骨架屏全屏模式适配
- ✅ 国际化文本完整性
- ✅ API 接口功能正确性
- ✅ 数据库迁移执行成功

View File

@@ -18,7 +18,7 @@ English | [中文](README_CN.md)
## Demo
Try Sub2API online: **https://v2.pincc.ai/**
Try Sub2API online: **https://demo.sub2api.org/**
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):

View File

@@ -57,6 +57,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
---
## OpenAI Responses 兼容注意事项
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id``tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
---
## 部署方式
### 方式一:脚本安装(推荐)

View File

@@ -8,6 +8,7 @@ import (
"errors"
"flag"
"log"
"log/slog"
"net/http"
"os"
"os/signal"
@@ -44,7 +45,25 @@ func init() {
}
}
// initLogger configures the default slog handler based on gin.Mode().
// In non-release mode, Debug level logs are enabled.
func initLogger() {
var level slog.Level
if gin.Mode() == gin.ReleaseMode {
level = slog.LevelInfo
} else {
level = slog.LevelDebug
}
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
slog.SetDefault(slog.New(handler))
}
func main() {
// Initialize slog logger based on gin mode
initLogger()
// Parse command line flags
setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
showVersion := flag.Bool("version", false, "Show version information")

View File

@@ -70,6 +70,8 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
@@ -123,6 +125,12 @@ func provideCleanup(
}
return nil
}},
{"UsageCleanupService", func() error {
if usageCleanup != nil {
usageCleanup.Stop()
}
return nil
}},
{"TokenRefreshService", func() error {
tokenRefresh.Stop()
return nil
@@ -131,6 +139,10 @@ func provideCleanup(
accountExpiry.Stop()
return nil
}},
{"SubscriptionExpiryService", func() error {
subscriptionExpiry.Stop()
return nil
}},
{"PricingService", func() error {
pricing.Stop()
return nil

View File

@@ -63,11 +63,16 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil {
return nil, err
}
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService)
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
@@ -76,15 +81,21 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
timingWheelService := service.ProvideTimingWheelService()
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
timingWheelService, err := service.ProvideTimingWheelService()
if err != nil {
return nil, err
}
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
accountRepository := repository.NewAccountRepository(client, db)
schedulerCache := repository.NewSchedulerCache(redisClient)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
@@ -98,25 +109,25 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
identityCache := repository.NewIdentityCache(redisClient)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
schedulerCache := repository.NewSchedulerCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -125,17 +136,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityCache := repository.NewIdentityCache(redisClient)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
@@ -146,7 +160,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService)
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
@@ -155,7 +171,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
totpHandler := handler.NewTotpHandler(totpService)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -166,9 +183,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -201,6 +219,8 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
@@ -253,6 +273,12 @@ func provideCleanup(
}
return nil
}},
{"UsageCleanupService", func() error {
if usageCleanup != nil {
usageCleanup.Stop()
}
return nil
}},
{"TokenRefreshService", func() error {
tokenRefresh.Stop()
return nil
@@ -261,6 +287,10 @@ func provideCleanup(
accountExpiry.Stop()
return nil
}},
{"SubscriptionExpiryService", func() error {
subscriptionExpiry.Stop()
return nil
}},
{"PricingService", func() error {
pricing.Stop()
return nil

View File

@@ -43,6 +43,8 @@ type Account struct {
Concurrency int `json:"concurrency,omitempty"`
// Priority holds the value of the "priority" field.
Priority int `json:"priority,omitempty"`
// RateMultiplier holds the value of the "rate_multiplier" field.
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// ErrorMessage holds the value of the "error_message" field.
@@ -135,6 +137,8 @@ func (*Account) scanValues(columns []string) ([]any, error) {
values[i] = new([]byte)
case account.FieldAutoPauseOnExpired, account.FieldSchedulable:
values[i] = new(sql.NullBool)
case account.FieldRateMultiplier:
values[i] = new(sql.NullFloat64)
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
values[i] = new(sql.NullInt64)
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
@@ -241,6 +245,12 @@ func (_m *Account) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Priority = int(value.Int64)
}
case account.FieldRateMultiplier:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field rate_multiplier", values[i])
} else if value.Valid {
_m.RateMultiplier = value.Float64
}
case account.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
@@ -420,6 +430,9 @@ func (_m *Account) String() string {
builder.WriteString("priority=")
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
builder.WriteString(", ")
builder.WriteString("rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")

View File

@@ -39,6 +39,8 @@ const (
FieldConcurrency = "concurrency"
// FieldPriority holds the string denoting the priority field in the database.
FieldPriority = "priority"
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
FieldRateMultiplier = "rate_multiplier"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldErrorMessage holds the string denoting the error_message field in the database.
@@ -116,6 +118,7 @@ var Columns = []string{
FieldProxyID,
FieldConcurrency,
FieldPriority,
FieldRateMultiplier,
FieldStatus,
FieldErrorMessage,
FieldLastUsedAt,
@@ -174,6 +177,8 @@ var (
DefaultConcurrency int
// DefaultPriority holds the default value on creation for the "priority" field.
DefaultPriority int
// DefaultRateMultiplier holds the default value on creation for the "rate_multiplier" field.
DefaultRateMultiplier float64
// DefaultStatus holds the default value on creation for the "status" field.
DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
@@ -244,6 +249,11 @@ func ByPriority(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPriority, opts...).ToFunc()
}
// ByRateMultiplier orders the results by the rate_multiplier field.
func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
}
// ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()

View File

@@ -105,6 +105,11 @@ func Priority(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldPriority, v))
}
// RateMultiplier applies equality check predicate on the "rate_multiplier" field. It's identical to RateMultiplierEQ.
func RateMultiplier(v float64) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
}
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldStatus, v))
@@ -675,6 +680,46 @@ func PriorityLTE(v int) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldPriority, v))
}
// RateMultiplierEQ applies the EQ predicate on the "rate_multiplier" field.
func RateMultiplierEQ(v float64) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
}
// RateMultiplierNEQ applies the NEQ predicate on the "rate_multiplier" field.
func RateMultiplierNEQ(v float64) predicate.Account {
return predicate.Account(sql.FieldNEQ(FieldRateMultiplier, v))
}
// RateMultiplierIn applies the In predicate on the "rate_multiplier" field.
func RateMultiplierIn(vs ...float64) predicate.Account {
return predicate.Account(sql.FieldIn(FieldRateMultiplier, vs...))
}
// RateMultiplierNotIn applies the NotIn predicate on the "rate_multiplier" field.
func RateMultiplierNotIn(vs ...float64) predicate.Account {
return predicate.Account(sql.FieldNotIn(FieldRateMultiplier, vs...))
}
// RateMultiplierGT applies the GT predicate on the "rate_multiplier" field.
func RateMultiplierGT(v float64) predicate.Account {
return predicate.Account(sql.FieldGT(FieldRateMultiplier, v))
}
// RateMultiplierGTE applies the GTE predicate on the "rate_multiplier" field.
func RateMultiplierGTE(v float64) predicate.Account {
return predicate.Account(sql.FieldGTE(FieldRateMultiplier, v))
}
// RateMultiplierLT applies the LT predicate on the "rate_multiplier" field.
func RateMultiplierLT(v float64) predicate.Account {
return predicate.Account(sql.FieldLT(FieldRateMultiplier, v))
}
// RateMultiplierLTE applies the LTE predicate on the "rate_multiplier" field.
func RateMultiplierLTE(v float64) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldRateMultiplier, v))
}
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldStatus, v))

View File

@@ -153,6 +153,20 @@ func (_c *AccountCreate) SetNillablePriority(v *int) *AccountCreate {
return _c
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (_c *AccountCreate) SetRateMultiplier(v float64) *AccountCreate {
_c.mutation.SetRateMultiplier(v)
return _c
}
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
func (_c *AccountCreate) SetNillableRateMultiplier(v *float64) *AccountCreate {
if v != nil {
_c.SetRateMultiplier(*v)
}
return _c
}
// SetStatus sets the "status" field.
func (_c *AccountCreate) SetStatus(v string) *AccountCreate {
_c.mutation.SetStatus(v)
@@ -429,6 +443,10 @@ func (_c *AccountCreate) defaults() error {
v := account.DefaultPriority
_c.mutation.SetPriority(v)
}
if _, ok := _c.mutation.RateMultiplier(); !ok {
v := account.DefaultRateMultiplier
_c.mutation.SetRateMultiplier(v)
}
if _, ok := _c.mutation.Status(); !ok {
v := account.DefaultStatus
_c.mutation.SetStatus(v)
@@ -488,6 +506,9 @@ func (_c *AccountCreate) check() error {
if _, ok := _c.mutation.Priority(); !ok {
return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "Account.priority"`)}
}
if _, ok := _c.mutation.RateMultiplier(); !ok {
return &ValidationError{Name: "rate_multiplier", err: errors.New(`ent: missing required field "Account.rate_multiplier"`)}
}
if _, ok := _c.mutation.Status(); !ok {
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Account.status"`)}
}
@@ -578,6 +599,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
_spec.SetField(account.FieldPriority, field.TypeInt, value)
_node.Priority = value
}
if value, ok := _c.mutation.RateMultiplier(); ok {
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
_node.RateMultiplier = value
}
if value, ok := _c.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value)
_node.Status = value
@@ -893,6 +918,24 @@ func (u *AccountUpsert) AddPriority(v int) *AccountUpsert {
return u
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (u *AccountUpsert) SetRateMultiplier(v float64) *AccountUpsert {
u.Set(account.FieldRateMultiplier, v)
return u
}
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
func (u *AccountUpsert) UpdateRateMultiplier() *AccountUpsert {
u.SetExcluded(account.FieldRateMultiplier)
return u
}
// AddRateMultiplier adds v to the "rate_multiplier" field.
func (u *AccountUpsert) AddRateMultiplier(v float64) *AccountUpsert {
u.Add(account.FieldRateMultiplier, v)
return u
}
// SetStatus sets the "status" field.
func (u *AccountUpsert) SetStatus(v string) *AccountUpsert {
u.Set(account.FieldStatus, v)
@@ -1325,6 +1368,27 @@ func (u *AccountUpsertOne) UpdatePriority() *AccountUpsertOne {
})
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (u *AccountUpsertOne) SetRateMultiplier(v float64) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.SetRateMultiplier(v)
})
}
// AddRateMultiplier adds v to the "rate_multiplier" field.
func (u *AccountUpsertOne) AddRateMultiplier(v float64) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.AddRateMultiplier(v)
})
}
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
func (u *AccountUpsertOne) UpdateRateMultiplier() *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.UpdateRateMultiplier()
})
}
// SetStatus sets the "status" field.
func (u *AccountUpsertOne) SetStatus(v string) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
@@ -1956,6 +2020,27 @@ func (u *AccountUpsertBulk) UpdatePriority() *AccountUpsertBulk {
})
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (u *AccountUpsertBulk) SetRateMultiplier(v float64) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.SetRateMultiplier(v)
})
}
// AddRateMultiplier adds v to the "rate_multiplier" field.
func (u *AccountUpsertBulk) AddRateMultiplier(v float64) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.AddRateMultiplier(v)
})
}
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
func (u *AccountUpsertBulk) UpdateRateMultiplier() *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.UpdateRateMultiplier()
})
}
// SetStatus sets the "status" field.
func (u *AccountUpsertBulk) SetStatus(v string) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {

View File

@@ -193,6 +193,27 @@ func (_u *AccountUpdate) AddPriority(v int) *AccountUpdate {
return _u
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (_u *AccountUpdate) SetRateMultiplier(v float64) *AccountUpdate {
_u.mutation.ResetRateMultiplier()
_u.mutation.SetRateMultiplier(v)
return _u
}
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
func (_u *AccountUpdate) SetNillableRateMultiplier(v *float64) *AccountUpdate {
if v != nil {
_u.SetRateMultiplier(*v)
}
return _u
}
// AddRateMultiplier adds value to the "rate_multiplier" field.
func (_u *AccountUpdate) AddRateMultiplier(v float64) *AccountUpdate {
_u.mutation.AddRateMultiplier(v)
return _u
}
// SetStatus sets the "status" field.
func (_u *AccountUpdate) SetStatus(v string) *AccountUpdate {
_u.mutation.SetStatus(v)
@@ -629,6 +650,12 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedPriority(); ok {
_spec.AddField(account.FieldPriority, field.TypeInt, value)
}
if value, ok := _u.mutation.RateMultiplier(); ok {
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value)
}
@@ -1005,6 +1032,27 @@ func (_u *AccountUpdateOne) AddPriority(v int) *AccountUpdateOne {
return _u
}
// SetRateMultiplier sets the "rate_multiplier" field.
func (_u *AccountUpdateOne) SetRateMultiplier(v float64) *AccountUpdateOne {
_u.mutation.ResetRateMultiplier()
_u.mutation.SetRateMultiplier(v)
return _u
}
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
func (_u *AccountUpdateOne) SetNillableRateMultiplier(v *float64) *AccountUpdateOne {
if v != nil {
_u.SetRateMultiplier(*v)
}
return _u
}
// AddRateMultiplier adds value to the "rate_multiplier" field.
func (_u *AccountUpdateOne) AddRateMultiplier(v float64) *AccountUpdateOne {
_u.mutation.AddRateMultiplier(v)
return _u
}
// SetStatus sets the "status" field.
func (_u *AccountUpdateOne) SetStatus(v string) *AccountUpdateOne {
_u.mutation.SetStatus(v)
@@ -1471,6 +1519,12 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
if value, ok := _u.mutation.AddedPriority(); ok {
_spec.AddField(account.FieldPriority, field.TypeInt, value)
}
if value, ok := _u.mutation.RateMultiplier(); ok {
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value)
}

View File

@@ -24,6 +24,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -57,6 +58,8 @@ type Client struct {
RedeemCode *RedeemCodeClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
UsageCleanupTask *UsageCleanupTaskClient
// UsageLog is the client for interacting with the UsageLog builders.
UsageLog *UsageLogClient
// User is the client for interacting with the User builders.
@@ -89,6 +92,7 @@ func (c *Client) init() {
c.Proxy = NewProxyClient(c.config)
c.RedeemCode = NewRedeemCodeClient(c.config)
c.Setting = NewSettingClient(c.config)
c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config)
c.UsageLog = NewUsageLogClient(c.config)
c.User = NewUserClient(c.config)
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
@@ -196,6 +200,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg),
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
@@ -230,6 +235,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg),
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
@@ -266,8 +272,9 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
c.Proxy, c.RedeemCode, c.Setting, c.UsageLog, c.User, c.UserAllowedGroup,
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Use(hooks...)
}
@@ -278,8 +285,9 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
c.Proxy, c.RedeemCode, c.Setting, c.UsageLog, c.User, c.UserAllowedGroup,
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Intercept(interceptors...)
}
@@ -306,6 +314,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.RedeemCode.mutate(ctx, m)
case *SettingMutation:
return c.Setting.mutate(ctx, m)
case *UsageCleanupTaskMutation:
return c.UsageCleanupTask.mutate(ctx, m)
case *UsageLogMutation:
return c.UsageLog.mutate(ctx, m)
case *UserMutation:
@@ -1847,6 +1857,139 @@ func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value,
}
}
// UsageCleanupTaskClient is a client for the UsageCleanupTask schema.
type UsageCleanupTaskClient struct {
config
}
// NewUsageCleanupTaskClient returns a client for the UsageCleanupTask from the given config.
func NewUsageCleanupTaskClient(c config) *UsageCleanupTaskClient {
return &UsageCleanupTaskClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `usagecleanuptask.Hooks(f(g(h())))`.
func (c *UsageCleanupTaskClient) Use(hooks ...Hook) {
c.hooks.UsageCleanupTask = append(c.hooks.UsageCleanupTask, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `usagecleanuptask.Intercept(f(g(h())))`.
func (c *UsageCleanupTaskClient) Intercept(interceptors ...Interceptor) {
c.inters.UsageCleanupTask = append(c.inters.UsageCleanupTask, interceptors...)
}
// Create returns a builder for creating a UsageCleanupTask entity.
func (c *UsageCleanupTaskClient) Create() *UsageCleanupTaskCreate {
mutation := newUsageCleanupTaskMutation(c.config, OpCreate)
return &UsageCleanupTaskCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of UsageCleanupTask entities.
func (c *UsageCleanupTaskClient) CreateBulk(builders ...*UsageCleanupTaskCreate) *UsageCleanupTaskCreateBulk {
return &UsageCleanupTaskCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *UsageCleanupTaskClient) MapCreateBulk(slice any, setFunc func(*UsageCleanupTaskCreate, int)) *UsageCleanupTaskCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &UsageCleanupTaskCreateBulk{err: fmt.Errorf("calling to UsageCleanupTaskClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*UsageCleanupTaskCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &UsageCleanupTaskCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for UsageCleanupTask.
func (c *UsageCleanupTaskClient) Update() *UsageCleanupTaskUpdate {
mutation := newUsageCleanupTaskMutation(c.config, OpUpdate)
return &UsageCleanupTaskUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *UsageCleanupTaskClient) UpdateOne(_m *UsageCleanupTask) *UsageCleanupTaskUpdateOne {
mutation := newUsageCleanupTaskMutation(c.config, OpUpdateOne, withUsageCleanupTask(_m))
return &UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *UsageCleanupTaskClient) UpdateOneID(id int64) *UsageCleanupTaskUpdateOne {
mutation := newUsageCleanupTaskMutation(c.config, OpUpdateOne, withUsageCleanupTaskID(id))
return &UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for UsageCleanupTask.
func (c *UsageCleanupTaskClient) Delete() *UsageCleanupTaskDelete {
mutation := newUsageCleanupTaskMutation(c.config, OpDelete)
return &UsageCleanupTaskDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *UsageCleanupTaskClient) DeleteOne(_m *UsageCleanupTask) *UsageCleanupTaskDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *UsageCleanupTaskClient) DeleteOneID(id int64) *UsageCleanupTaskDeleteOne {
builder := c.Delete().Where(usagecleanuptask.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &UsageCleanupTaskDeleteOne{builder}
}
// Query returns a query builder for UsageCleanupTask.
func (c *UsageCleanupTaskClient) Query() *UsageCleanupTaskQuery {
return &UsageCleanupTaskQuery{
config: c.config,
ctx: &QueryContext{Type: TypeUsageCleanupTask},
inters: c.Interceptors(),
}
}
// Get returns a UsageCleanupTask entity by its id.
func (c *UsageCleanupTaskClient) Get(ctx context.Context, id int64) (*UsageCleanupTask, error) {
return c.Query().Where(usagecleanuptask.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *UsageCleanupTaskClient) GetX(ctx context.Context, id int64) *UsageCleanupTask {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// Hooks returns the client hooks.
func (c *UsageCleanupTaskClient) Hooks() []Hook {
return c.hooks.UsageCleanupTask
}
// Interceptors returns the client interceptors.
func (c *UsageCleanupTaskClient) Interceptors() []Interceptor {
return c.inters.UsageCleanupTask
}
func (c *UsageCleanupTaskClient) mutate(ctx context.Context, m *UsageCleanupTaskMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&UsageCleanupTaskCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&UsageCleanupTaskUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&UsageCleanupTaskDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown UsageCleanupTask mutation op: %q", m.Op())
}
}
// UsageLogClient is a client for the UsageLog schema.
type UsageLogClient struct {
config
@@ -2974,13 +3117,13 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
type (
hooks struct {
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
RedeemCode, Setting, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeValue, UserSubscription []ent.Hook
RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
RedeemCode, Setting, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeValue, UserSubscription []ent.Interceptor
RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
}
)

View File

@@ -21,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -96,6 +97,7 @@ func checkColumn(t, c string) error {
proxy.Table: proxy.ValidColumn,
redeemcode.Table: redeemcode.ValidColumn,
setting.Table: setting.ValidColumn,
usagecleanuptask.Table: usagecleanuptask.ValidColumn,
usagelog.Table: usagelog.ValidColumn,
user.Table: user.ValidColumn,
userallowedgroup.Table: userallowedgroup.ValidColumn,

View File

@@ -3,6 +3,7 @@
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
@@ -55,6 +56,10 @@ type Group struct {
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// 模型路由配置:模型模式 -> 优先账号ID列表
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
// 是否启用模型路由配置
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -161,7 +166,9 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case group.FieldIsExclusive, group.FieldClaudeCodeOnly:
case group.FieldModelRouting:
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
@@ -315,6 +322,20 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.FallbackGroupID = new(int64)
*_m.FallbackGroupID = value.Int64
}
case group.FieldModelRouting:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field model_routing", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.ModelRouting); err != nil {
return fmt.Errorf("unmarshal field model_routing: %w", err)
}
}
case group.FieldModelRoutingEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field model_routing_enabled", values[i])
} else if value.Valid {
_m.ModelRoutingEnabled = value.Bool
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -465,6 +486,12 @@ func (_m *Group) String() string {
builder.WriteString("fallback_group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("model_routing=")
builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting))
builder.WriteString(", ")
builder.WriteString("model_routing_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -53,6 +53,10 @@ const (
FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
FieldFallbackGroupID = "fallback_group_id"
// FieldModelRouting holds the string denoting the model_routing field in the database.
FieldModelRouting = "model_routing"
// FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database.
FieldModelRoutingEnabled = "model_routing_enabled"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -147,6 +151,8 @@ var Columns = []string{
FieldImagePrice4k,
FieldClaudeCodeOnly,
FieldFallbackGroupID,
FieldModelRouting,
FieldModelRoutingEnabled,
}
var (
@@ -204,6 +210,8 @@ var (
DefaultDefaultValidityDays int
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
DefaultModelRoutingEnabled bool
)
// OrderOption defines the ordering options for the Group queries.
@@ -309,6 +317,11 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
}
// ByModelRoutingEnabled orders the results by the model_routing_enabled field.
func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -150,6 +150,11 @@ func FallbackGroupID(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
}
// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ.
func ModelRoutingEnabled(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -1065,6 +1070,26 @@ func FallbackGroupIDNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
}
// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field.
func ModelRoutingIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldModelRouting))
}
// ModelRoutingNotNil applies the NotNil predicate on the "model_routing" field.
func ModelRoutingNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldModelRouting))
}
// ModelRoutingEnabledEQ applies the EQ predicate on the "model_routing_enabled" field.
func ModelRoutingEnabledEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
}
// ModelRoutingEnabledNEQ applies the NEQ predicate on the "model_routing_enabled" field.
func ModelRoutingEnabledNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {

View File

@@ -286,6 +286,26 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
return _c
}
// SetModelRouting sets the "model_routing" field.
func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate {
_c.mutation.SetModelRouting(v)
return _c
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (_c *GroupCreate) SetModelRoutingEnabled(v bool) *GroupCreate {
_c.mutation.SetModelRoutingEnabled(v)
return _c
}
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
func (_c *GroupCreate) SetNillableModelRoutingEnabled(v *bool) *GroupCreate {
if v != nil {
_c.SetModelRoutingEnabled(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -455,6 +475,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v)
}
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
v := group.DefaultModelRoutingEnabled
_c.mutation.SetModelRoutingEnabled(v)
}
return nil
}
@@ -510,6 +534,9 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
}
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)}
}
return nil
}
@@ -613,6 +640,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
_node.FallbackGroupID = &value
}
if value, ok := _c.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
_node.ModelRouting = value
}
if value, ok := _c.mutation.ModelRoutingEnabled(); ok {
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
_node.ModelRoutingEnabled = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1093,6 +1128,36 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
return u
}
// SetModelRouting sets the "model_routing" field.
func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert {
u.Set(group.FieldModelRouting, v)
return u
}
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
func (u *GroupUpsert) UpdateModelRouting() *GroupUpsert {
u.SetExcluded(group.FieldModelRouting)
return u
}
// ClearModelRouting clears the value of the "model_routing" field.
func (u *GroupUpsert) ClearModelRouting() *GroupUpsert {
u.SetNull(group.FieldModelRouting)
return u
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (u *GroupUpsert) SetModelRoutingEnabled(v bool) *GroupUpsert {
u.Set(group.FieldModelRoutingEnabled, v)
return u
}
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
func (u *GroupUpsert) UpdateModelRoutingEnabled() *GroupUpsert {
u.SetExcluded(group.FieldModelRoutingEnabled)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1516,6 +1581,41 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
})
}
// SetModelRouting sets the "model_routing" field.
func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetModelRouting(v)
})
}
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateModelRouting() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelRouting()
})
}
// ClearModelRouting clears the value of the "model_routing" field.
func (u *GroupUpsertOne) ClearModelRouting() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearModelRouting()
})
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (u *GroupUpsertOne) SetModelRoutingEnabled(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetModelRoutingEnabled(v)
})
}
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateModelRoutingEnabled() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelRoutingEnabled()
})
}
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -2105,6 +2205,41 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
})
}
// SetModelRouting sets the "model_routing" field.
func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetModelRouting(v)
})
}
// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateModelRouting() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelRouting()
})
}
// ClearModelRouting clears the value of the "model_routing" field.
func (u *GroupUpsertBulk) ClearModelRouting() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearModelRouting()
})
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (u *GroupUpsertBulk) SetModelRoutingEnabled(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetModelRoutingEnabled(v)
})
}
// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateModelRoutingEnabled() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelRoutingEnabled()
})
}
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -395,6 +395,32 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
return _u
}
// SetModelRouting sets the "model_routing" field.
func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate {
_u.mutation.SetModelRouting(v)
return _u
}
// ClearModelRouting clears the value of the "model_routing" field.
func (_u *GroupUpdate) ClearModelRouting() *GroupUpdate {
_u.mutation.ClearModelRouting()
return _u
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (_u *GroupUpdate) SetModelRoutingEnabled(v bool) *GroupUpdate {
_u.mutation.SetModelRoutingEnabled(v)
return _u
}
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableModelRoutingEnabled(v *bool) *GroupUpdate {
if v != nil {
_u.SetModelRoutingEnabled(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -803,6 +829,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
if value, ok := _u.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
}
if _u.mutation.ModelRoutingCleared() {
_spec.ClearField(group.FieldModelRouting, field.TypeJSON)
}
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1478,6 +1513,32 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
return _u
}
// SetModelRouting sets the "model_routing" field.
func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne {
_u.mutation.SetModelRouting(v)
return _u
}
// ClearModelRouting clears the value of the "model_routing" field.
func (_u *GroupUpdateOne) ClearModelRouting() *GroupUpdateOne {
_u.mutation.ClearModelRouting()
return _u
}
// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (_u *GroupUpdateOne) SetModelRoutingEnabled(v bool) *GroupUpdateOne {
_u.mutation.SetModelRoutingEnabled(v)
return _u
}
// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableModelRoutingEnabled(v *bool) *GroupUpdateOne {
if v != nil {
_u.SetModelRoutingEnabled(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1916,6 +1977,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
if value, ok := _u.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
}
if _u.mutation.ModelRoutingCleared() {
_spec.ClearField(group.FieldModelRouting, field.TypeJSON)
}
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,

View File

@@ -117,6 +117,18 @@ func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, err
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m)
}
// The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary
// function as UsageCleanupTask mutator.
type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f UsageCleanupTaskFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.UsageCleanupTaskMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UsageCleanupTaskMutation", m)
}
// The UsageLogFunc type is an adapter to allow the use of ordinary
// function as UsageLog mutator.
type UsageLogFunc func(context.Context, *ent.UsageLogMutation) (ent.Value, error)

View File

@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -325,6 +326,33 @@ func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error {
return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q)
}
// The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary function as a Querier.
type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f UsageCleanupTaskFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.UsageCleanupTaskQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UsageCleanupTaskQuery", q)
}
// The TraverseUsageCleanupTask type is an adapter to allow the use of ordinary function as Traverser.
type TraverseUsageCleanupTask func(context.Context, *ent.UsageCleanupTaskQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseUsageCleanupTask) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseUsageCleanupTask) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.UsageCleanupTaskQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.UsageCleanupTaskQuery", q)
}
// The UsageLogFunc type is an adapter to allow the use of ordinary function as a Querier.
type UsageLogFunc func(context.Context, *ent.UsageLogQuery) (ent.Value, error)
@@ -508,6 +536,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil
case *ent.SettingQuery:
return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil
case *ent.UsageCleanupTaskQuery:
return &query[*ent.UsageCleanupTaskQuery, predicate.UsageCleanupTask, usagecleanuptask.OrderOption]{typ: ent.TypeUsageCleanupTask, tq: q}, nil
case *ent.UsageLogQuery:
return &query[*ent.UsageLogQuery, predicate.UsageLog, usagelog.OrderOption]{typ: ent.TypeUsageLog, tq: q}, nil
case *ent.UserQuery:

View File

@@ -79,6 +79,7 @@ var (
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "concurrency", Type: field.TypeInt, Default: 3},
{Name: "priority", Type: field.TypeInt, Default: 50},
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "error_message", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "last_used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
@@ -101,7 +102,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "accounts_proxies_proxy",
Columns: []*schema.Column{AccountsColumns[24]},
Columns: []*schema.Column{AccountsColumns[25]},
RefColumns: []*schema.Column{ProxiesColumns[0]},
OnDelete: schema.SetNull,
},
@@ -120,12 +121,12 @@ var (
{
Name: "account_status",
Unique: false,
Columns: []*schema.Column{AccountsColumns[12]},
Columns: []*schema.Column{AccountsColumns[13]},
},
{
Name: "account_proxy_id",
Unique: false,
Columns: []*schema.Column{AccountsColumns[24]},
Columns: []*schema.Column{AccountsColumns[25]},
},
{
Name: "account_priority",
@@ -135,27 +136,27 @@ var (
{
Name: "account_last_used_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[14]},
Columns: []*schema.Column{AccountsColumns[15]},
},
{
Name: "account_schedulable",
Unique: false,
Columns: []*schema.Column{AccountsColumns[17]},
Columns: []*schema.Column{AccountsColumns[18]},
},
{
Name: "account_rate_limited_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[18]},
Columns: []*schema.Column{AccountsColumns[19]},
},
{
Name: "account_rate_limit_reset_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[19]},
Columns: []*schema.Column{AccountsColumns[20]},
},
{
Name: "account_overload_until",
Unique: false,
Columns: []*schema.Column{AccountsColumns[20]},
Columns: []*schema.Column{AccountsColumns[21]},
},
{
Name: "account_deleted_at",
@@ -225,6 +226,8 @@ var (
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{
@@ -431,6 +434,44 @@ var (
Columns: SettingsColumns,
PrimaryKey: []*schema.Column{SettingsColumns[0]},
}
// UsageCleanupTasksColumns holds the columns for the "usage_cleanup_tasks" table.
UsageCleanupTasksColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "status", Type: field.TypeString, Size: 20},
{Name: "filters", Type: field.TypeJSON},
{Name: "created_by", Type: field.TypeInt64},
{Name: "deleted_rows", Type: field.TypeInt64, Default: 0},
{Name: "error_message", Type: field.TypeString, Nullable: true},
{Name: "canceled_by", Type: field.TypeInt64, Nullable: true},
{Name: "canceled_at", Type: field.TypeTime, Nullable: true},
{Name: "started_at", Type: field.TypeTime, Nullable: true},
{Name: "finished_at", Type: field.TypeTime, Nullable: true},
}
// UsageCleanupTasksTable holds the schema information for the "usage_cleanup_tasks" table.
UsageCleanupTasksTable = &schema.Table{
Name: "usage_cleanup_tasks",
Columns: UsageCleanupTasksColumns,
PrimaryKey: []*schema.Column{UsageCleanupTasksColumns[0]},
Indexes: []*schema.Index{
{
Name: "usagecleanuptask_status_created_at",
Unique: false,
Columns: []*schema.Column{UsageCleanupTasksColumns[3], UsageCleanupTasksColumns[1]},
},
{
Name: "usagecleanuptask_created_at",
Unique: false,
Columns: []*schema.Column{UsageCleanupTasksColumns[1]},
},
{
Name: "usagecleanuptask_canceled_at",
Unique: false,
Columns: []*schema.Column{UsageCleanupTasksColumns[9]},
},
},
}
// UsageLogsColumns holds the columns for the "usage_logs" table.
UsageLogsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -449,6 +490,7 @@ var (
{Name: "total_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "actual_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "account_rate_multiplier", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "billing_type", Type: field.TypeInt8, Default: 0},
{Name: "stream", Type: field.TypeBool, Default: false},
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
@@ -472,31 +514,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[25]},
Columns: []*schema.Column{UsageLogsColumns[26]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[27]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -505,32 +547,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[25]},
Columns: []*schema.Column{UsageLogsColumns[26]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[27]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[28]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[30]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[24]},
Columns: []*schema.Column{UsageLogsColumns[25]},
},
{
Name: "usagelog_model",
@@ -545,12 +587,12 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]},
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]},
Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]},
},
},
}
@@ -568,6 +610,9 @@ var (
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
@@ -801,6 +846,7 @@ var (
ProxiesTable,
RedeemCodesTable,
SettingsTable,
UsageCleanupTasksTable,
UsageLogsTable,
UsersTable,
UserAllowedGroupsTable,
@@ -847,6 +893,9 @@ func init() {
SettingsTable.Annotation = &entsql.Annotation{
Table: "settings",
}
UsageCleanupTasksTable.Annotation = &entsql.Annotation{
Table: "usage_cleanup_tasks",
}
UsageLogsTable.ForeignKeys[0].RefTable = APIKeysTable
UsageLogsTable.ForeignKeys[1].RefTable = AccountsTable
UsageLogsTable.ForeignKeys[2].RefTable = GroupsTable

File diff suppressed because it is too large Load Diff

View File

@@ -33,6 +33,9 @@ type RedeemCode func(*sql.Selector)
// Setting is the predicate function for setting builders.
type Setting func(*sql.Selector)
// UsageCleanupTask is the predicate function for usagecleanuptask builders.
type UsageCleanupTask func(*sql.Selector)
// UsageLog is the predicate function for usagelog builders.
type UsageLog func(*sql.Selector)

View File

@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/schema"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -177,22 +178,26 @@ func init() {
accountDescPriority := accountFields[8].Descriptor()
// account.DefaultPriority holds the default value on creation for the priority field.
account.DefaultPriority = accountDescPriority.Default.(int)
// accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
accountDescRateMultiplier := accountFields[9].Descriptor()
// account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
// accountDescStatus is the schema descriptor for status field.
accountDescStatus := accountFields[9].Descriptor()
accountDescStatus := accountFields[10].Descriptor()
// account.DefaultStatus holds the default value on creation for the status field.
account.DefaultStatus = accountDescStatus.Default.(string)
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
accountDescAutoPauseOnExpired := accountFields[13].Descriptor()
accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
// accountDescSchedulable is the schema descriptor for schedulable field.
accountDescSchedulable := accountFields[14].Descriptor()
accountDescSchedulable := accountFields[15].Descriptor()
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
accountDescSessionWindowStatus := accountFields[20].Descriptor()
accountDescSessionWindowStatus := accountFields[21].Descriptor()
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
accountgroupFields := schema.AccountGroup{}.Fields()
@@ -276,6 +281,10 @@ func init() {
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
groupDescModelRoutingEnabled := groupFields[17].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields
// promocodeDescCode is the schema descriptor for code field.
@@ -487,6 +496,43 @@ func init() {
setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time)
// setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time)
usagecleanuptaskMixin := schema.UsageCleanupTask{}.Mixin()
usagecleanuptaskMixinFields0 := usagecleanuptaskMixin[0].Fields()
_ = usagecleanuptaskMixinFields0
usagecleanuptaskFields := schema.UsageCleanupTask{}.Fields()
_ = usagecleanuptaskFields
// usagecleanuptaskDescCreatedAt is the schema descriptor for created_at field.
usagecleanuptaskDescCreatedAt := usagecleanuptaskMixinFields0[0].Descriptor()
// usagecleanuptask.DefaultCreatedAt holds the default value on creation for the created_at field.
usagecleanuptask.DefaultCreatedAt = usagecleanuptaskDescCreatedAt.Default.(func() time.Time)
// usagecleanuptaskDescUpdatedAt is the schema descriptor for updated_at field.
usagecleanuptaskDescUpdatedAt := usagecleanuptaskMixinFields0[1].Descriptor()
// usagecleanuptask.DefaultUpdatedAt holds the default value on creation for the updated_at field.
usagecleanuptask.DefaultUpdatedAt = usagecleanuptaskDescUpdatedAt.Default.(func() time.Time)
// usagecleanuptask.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
usagecleanuptask.UpdateDefaultUpdatedAt = usagecleanuptaskDescUpdatedAt.UpdateDefault.(func() time.Time)
// usagecleanuptaskDescStatus is the schema descriptor for status field.
usagecleanuptaskDescStatus := usagecleanuptaskFields[0].Descriptor()
// usagecleanuptask.StatusValidator is a validator for the "status" field. It is called by the builders before save.
usagecleanuptask.StatusValidator = func() func(string) error {
validators := usagecleanuptaskDescStatus.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(status string) error {
for _, fn := range fns {
if err := fn(status); err != nil {
return err
}
}
return nil
}
}()
// usagecleanuptaskDescDeletedRows is the schema descriptor for deleted_rows field.
usagecleanuptaskDescDeletedRows := usagecleanuptaskFields[3].Descriptor()
// usagecleanuptask.DefaultDeletedRows holds the default value on creation for the deleted_rows field.
usagecleanuptask.DefaultDeletedRows = usagecleanuptaskDescDeletedRows.Default.(int64)
usagelogFields := schema.UsageLog{}.Fields()
_ = usagelogFields
// usagelogDescRequestID is the schema descriptor for request_id field.
@@ -578,31 +624,31 @@ func init() {
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field.
usagelogDescBillingType := usagelogFields[20].Descriptor()
usagelogDescBillingType := usagelogFields[21].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field.
usagelogDescStream := usagelogFields[21].Descriptor()
usagelogDescStream := usagelogFields[22].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field.
usagelogDescUserAgent := usagelogFields[24].Descriptor()
usagelogDescUserAgent := usagelogFields[25].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[25].Descriptor()
usagelogDescIPAddress := usagelogFields[26].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[26].Descriptor()
usagelogDescImageCount := usagelogFields[27].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[27].Descriptor()
usagelogDescImageSize := usagelogFields[28].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[28].Descriptor()
usagelogDescCreatedAt := usagelogFields[29].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()
@@ -690,6 +736,10 @@ func init() {
userDescNotes := userFields[7].Descriptor()
// user.DefaultNotes holds the default value on creation for the notes field.
user.DefaultNotes = userDescNotes.Default.(string)
// userDescTotpEnabled is the schema descriptor for totp_enabled field.
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.

View File

@@ -102,6 +102,12 @@ func (Account) Fields() []ent.Field {
field.Int("priority").
Default(50),
// rate_multiplier: 账号计费倍率(>=0允许 0 表示该账号计费为 0
// 仅影响账号维度计费口径,不影响用户/API Key 扣费(分组倍率)
field.Float("rate_multiplier").
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
Default(1.0),
// status: 账户状态,如 "active", "error", "disabled"
field.String("status").
MaxLen(20).

View File

@@ -95,6 +95,17 @@ func (Group) Fields() []ent.Field {
Optional().
Nillable().
Comment("非 Claude Code 请求降级使用的分组 ID"),
// 模型路由配置 (added by migration 040)
field.JSON("model_routing", map[string][]int64{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
// 模型路由开关 (added by migration 041)
field.Bool("model_routing_enabled").
Default(false).
Comment("是否启用模型路由配置"),
}
}

View File

@@ -5,6 +5,7 @@ package mixins
import (
"context"
"fmt"
"reflect"
"time"
"entgo.io/ent"
@@ -12,7 +13,6 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/mixin"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/intercept"
)
@@ -113,7 +113,6 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook {
SetOp(ent.Op)
SetDeletedAt(time.Time)
WhereP(...func(*sql.Selector))
Client() *dbent.Client
})
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
@@ -124,7 +123,7 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook {
mx.SetOp(ent.OpUpdate)
// 设置删除时间为当前时间
mx.SetDeletedAt(time.Now())
return mx.Client().Mutate(ctx, m)
return mutateWithClient(ctx, m, next)
})
},
}
@@ -137,3 +136,41 @@ func (d SoftDeleteMixin) applyPredicate(w interface{ WhereP(...func(*sql.Selecto
sql.FieldIsNull(d.Fields()[0].Descriptor().Name),
)
}
func mutateWithClient(ctx context.Context, m ent.Mutation, fallback ent.Mutator) (ent.Value, error) {
clientMethod := reflect.ValueOf(m).MethodByName("Client")
if !clientMethod.IsValid() || clientMethod.Type().NumIn() != 0 || clientMethod.Type().NumOut() != 1 {
return nil, fmt.Errorf("soft delete: mutation client method not found for %T", m)
}
client := clientMethod.Call(nil)[0]
mutateMethod := client.MethodByName("Mutate")
if !mutateMethod.IsValid() {
return nil, fmt.Errorf("soft delete: mutation client missing Mutate for %T", m)
}
if mutateMethod.Type().NumIn() != 2 || mutateMethod.Type().NumOut() != 2 {
return nil, fmt.Errorf("soft delete: mutation client signature mismatch for %T", m)
}
results := mutateMethod.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(m)})
value := results[0].Interface()
var err error
if !results[1].IsNil() {
errValue := results[1].Interface()
typedErr, ok := errValue.(error)
if !ok {
return nil, fmt.Errorf("soft delete: unexpected error type %T for %T", errValue, m)
}
err = typedErr
}
if err != nil {
return nil, err
}
if value == nil {
return nil, fmt.Errorf("soft delete: mutation client returned nil for %T", m)
}
v, ok := value.(ent.Value)
if !ok {
return nil, fmt.Errorf("soft delete: unexpected value type %T for %T", value, m)
}
return v, nil
}

View File

@@ -0,0 +1,75 @@
package schema
import (
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// UsageCleanupTask 定义使用记录清理任务的 schema。
type UsageCleanupTask struct {
ent.Schema
}
func (UsageCleanupTask) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "usage_cleanup_tasks"},
}
}
func (UsageCleanupTask) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (UsageCleanupTask) Fields() []ent.Field {
return []ent.Field{
field.String("status").
MaxLen(20).
Validate(validateUsageCleanupStatus),
field.JSON("filters", json.RawMessage{}),
field.Int64("created_by"),
field.Int64("deleted_rows").
Default(0),
field.String("error_message").
Optional().
Nillable(),
field.Int64("canceled_by").
Optional().
Nillable(),
field.Time("canceled_at").
Optional().
Nillable(),
field.Time("started_at").
Optional().
Nillable(),
field.Time("finished_at").
Optional().
Nillable(),
}
}
func (UsageCleanupTask) Indexes() []ent.Index {
return []ent.Index{
index.Fields("status", "created_at"),
index.Fields("created_at"),
index.Fields("canceled_at"),
}
}
func validateUsageCleanupStatus(status string) error {
switch status {
case "pending", "running", "succeeded", "failed", "canceled":
return nil
default:
return fmt.Errorf("invalid usage cleanup status: %s", status)
}
}

View File

@@ -85,6 +85,12 @@ func (UsageLog) Fields() []ent.Field {
Default(1).
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
// account_rate_multiplier: 账号计费倍率快照NULL 表示按 1.0 处理)
field.Float("account_rate_multiplier").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
// 其他字段
field.Int8("billing_type").
Default(0),

View File

@@ -61,6 +61,17 @@ func (User) Fields() []ent.Field {
field.String("notes").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Default(""),
// TOTP 双因素认证字段
field.String("totp_secret_encrypted").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Optional().
Nillable(),
field.Bool("totp_enabled").
Default(false),
field.Time("totp_enabled_at").
Optional().
Nillable(),
}
}

View File

@@ -32,6 +32,8 @@ type Tx struct {
RedeemCode *RedeemCodeClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
UsageCleanupTask *UsageCleanupTaskClient
// UsageLog is the client for interacting with the UsageLog builders.
UsageLog *UsageLogClient
// User is the client for interacting with the User builders.
@@ -184,6 +186,7 @@ func (tx *Tx) init() {
tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.Setting = NewSettingClient(tx.config)
tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config)
tx.UsageLog = NewUsageLogClient(tx.config)
tx.User = NewUserClient(tx.config)
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)

View File

@@ -0,0 +1,236 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)
// UsageCleanupTask is the model entity for the UsageCleanupTask schema.
type UsageCleanupTask struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// Filters holds the value of the "filters" field.
Filters json.RawMessage `json:"filters,omitempty"`
// CreatedBy holds the value of the "created_by" field.
CreatedBy int64 `json:"created_by,omitempty"`
// DeletedRows holds the value of the "deleted_rows" field.
DeletedRows int64 `json:"deleted_rows,omitempty"`
// ErrorMessage holds the value of the "error_message" field.
ErrorMessage *string `json:"error_message,omitempty"`
// CanceledBy holds the value of the "canceled_by" field.
CanceledBy *int64 `json:"canceled_by,omitempty"`
// CanceledAt holds the value of the "canceled_at" field.
CanceledAt *time.Time `json:"canceled_at,omitempty"`
// StartedAt holds the value of the "started_at" field.
StartedAt *time.Time `json:"started_at,omitempty"`
// FinishedAt holds the value of the "finished_at" field.
FinishedAt *time.Time `json:"finished_at,omitempty"`
selectValues sql.SelectValues
}
// scanValues returns the types for scanning values from sql.Rows.
func (*UsageCleanupTask) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case usagecleanuptask.FieldFilters:
values[i] = new([]byte)
case usagecleanuptask.FieldID, usagecleanuptask.FieldCreatedBy, usagecleanuptask.FieldDeletedRows, usagecleanuptask.FieldCanceledBy:
values[i] = new(sql.NullInt64)
case usagecleanuptask.FieldStatus, usagecleanuptask.FieldErrorMessage:
values[i] = new(sql.NullString)
case usagecleanuptask.FieldCreatedAt, usagecleanuptask.FieldUpdatedAt, usagecleanuptask.FieldCanceledAt, usagecleanuptask.FieldStartedAt, usagecleanuptask.FieldFinishedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the UsageCleanupTask fields.
func (_m *UsageCleanupTask) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case usagecleanuptask.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case usagecleanuptask.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case usagecleanuptask.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case usagecleanuptask.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
} else if value.Valid {
_m.Status = value.String
}
case usagecleanuptask.FieldFilters:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field filters", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.Filters); err != nil {
return fmt.Errorf("unmarshal field filters: %w", err)
}
}
case usagecleanuptask.FieldCreatedBy:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field created_by", values[i])
} else if value.Valid {
_m.CreatedBy = value.Int64
}
case usagecleanuptask.FieldDeletedRows:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field deleted_rows", values[i])
} else if value.Valid {
_m.DeletedRows = value.Int64
}
case usagecleanuptask.FieldErrorMessage:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field error_message", values[i])
} else if value.Valid {
_m.ErrorMessage = new(string)
*_m.ErrorMessage = value.String
}
case usagecleanuptask.FieldCanceledBy:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field canceled_by", values[i])
} else if value.Valid {
_m.CanceledBy = new(int64)
*_m.CanceledBy = value.Int64
}
case usagecleanuptask.FieldCanceledAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field canceled_at", values[i])
} else if value.Valid {
_m.CanceledAt = new(time.Time)
*_m.CanceledAt = value.Time
}
case usagecleanuptask.FieldStartedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field started_at", values[i])
} else if value.Valid {
_m.StartedAt = new(time.Time)
*_m.StartedAt = value.Time
}
case usagecleanuptask.FieldFinishedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field finished_at", values[i])
} else if value.Valid {
_m.FinishedAt = new(time.Time)
*_m.FinishedAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the UsageCleanupTask.
// This includes values selected through modifiers, order, etc.
func (_m *UsageCleanupTask) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// Update returns a builder for updating this UsageCleanupTask.
// Note that you need to call UsageCleanupTask.Unwrap() before calling this method if this UsageCleanupTask
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *UsageCleanupTask) Update() *UsageCleanupTaskUpdateOne {
return NewUsageCleanupTaskClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the UsageCleanupTask entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *UsageCleanupTask) Unwrap() *UsageCleanupTask {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: UsageCleanupTask is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *UsageCleanupTask) String() string {
var builder strings.Builder
builder.WriteString("UsageCleanupTask(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
builder.WriteString("filters=")
builder.WriteString(fmt.Sprintf("%v", _m.Filters))
builder.WriteString(", ")
builder.WriteString("created_by=")
builder.WriteString(fmt.Sprintf("%v", _m.CreatedBy))
builder.WriteString(", ")
builder.WriteString("deleted_rows=")
builder.WriteString(fmt.Sprintf("%v", _m.DeletedRows))
builder.WriteString(", ")
if v := _m.ErrorMessage; v != nil {
builder.WriteString("error_message=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.CanceledBy; v != nil {
builder.WriteString("canceled_by=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.CanceledAt; v != nil {
builder.WriteString("canceled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.StartedAt; v != nil {
builder.WriteString("started_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.FinishedAt; v != nil {
builder.WriteString("finished_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteByte(')')
return builder.String()
}
// UsageCleanupTasks is a parsable slice of UsageCleanupTask.
type UsageCleanupTasks []*UsageCleanupTask

View File

@@ -0,0 +1,137 @@
// Code generated by ent, DO NOT EDIT.
package usagecleanuptask
import (
"time"
"entgo.io/ent/dialect/sql"
)
const (
// Label holds the string label denoting the usagecleanuptask type in the database.
Label = "usage_cleanup_task"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldFilters holds the string denoting the filters field in the database.
FieldFilters = "filters"
// FieldCreatedBy holds the string denoting the created_by field in the database.
FieldCreatedBy = "created_by"
// FieldDeletedRows holds the string denoting the deleted_rows field in the database.
FieldDeletedRows = "deleted_rows"
// FieldErrorMessage holds the string denoting the error_message field in the database.
FieldErrorMessage = "error_message"
// FieldCanceledBy holds the string denoting the canceled_by field in the database.
FieldCanceledBy = "canceled_by"
// FieldCanceledAt holds the string denoting the canceled_at field in the database.
FieldCanceledAt = "canceled_at"
// FieldStartedAt holds the string denoting the started_at field in the database.
FieldStartedAt = "started_at"
// FieldFinishedAt holds the string denoting the finished_at field in the database.
FieldFinishedAt = "finished_at"
// Table holds the table name of the usagecleanuptask in the database.
Table = "usage_cleanup_tasks"
)
// Columns holds all SQL columns for usagecleanuptask fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldStatus,
FieldFilters,
FieldCreatedBy,
FieldDeletedRows,
FieldErrorMessage,
FieldCanceledBy,
FieldCanceledAt,
FieldStartedAt,
FieldFinishedAt,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
StatusValidator func(string) error
// DefaultDeletedRows holds the default value on creation for the "deleted_rows" field.
DefaultDeletedRows int64
)
// OrderOption defines the ordering options for the UsageCleanupTask queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
}
// ByCreatedBy orders the results by the created_by field.
func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedBy, opts...).ToFunc()
}
// ByDeletedRows orders the results by the deleted_rows field.
func ByDeletedRows(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDeletedRows, opts...).ToFunc()
}
// ByErrorMessage orders the results by the error_message field.
func ByErrorMessage(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldErrorMessage, opts...).ToFunc()
}
// ByCanceledBy orders the results by the canceled_by field.
func ByCanceledBy(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCanceledBy, opts...).ToFunc()
}
// ByCanceledAt orders the results by the canceled_at field.
func ByCanceledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCanceledAt, opts...).ToFunc()
}
// ByStartedAt orders the results by the started_at field.
func ByStartedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStartedAt, opts...).ToFunc()
}
// ByFinishedAt orders the results by the finished_at field.
func ByFinishedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFinishedAt, opts...).ToFunc()
}

View File

@@ -0,0 +1,620 @@
// Code generated by ent, DO NOT EDIT.
package usagecleanuptask
import (
"time"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldUpdatedAt, v))
}
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldStatus, v))
}
// CreatedBy applies equality check predicate on the "created_by" field. It's identical to CreatedByEQ.
func CreatedBy(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedBy, v))
}
// DeletedRows applies equality check predicate on the "deleted_rows" field. It's identical to DeletedRowsEQ.
func DeletedRows(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldDeletedRows, v))
}
// ErrorMessage applies equality check predicate on the "error_message" field. It's identical to ErrorMessageEQ.
func ErrorMessage(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldErrorMessage, v))
}
// CanceledBy applies equality check predicate on the "canceled_by" field. It's identical to CanceledByEQ.
func CanceledBy(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledBy, v))
}
// CanceledAt applies equality check predicate on the "canceled_at" field. It's identical to CanceledAtEQ.
func CanceledAt(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledAt, v))
}
// StartedAt applies equality check predicate on the "started_at" field. It's identical to StartedAtEQ.
func StartedAt(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldStartedAt, v))
}
// FinishedAt applies equality check predicate on the "finished_at" field. It's identical to FinishedAtEQ.
func FinishedAt(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldFinishedAt, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldUpdatedAt, v))
}
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldStatus, v))
}
// StatusNEQ applies the NEQ predicate on the "status" field.
func StatusNEQ(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldStatus, v))
}
// StatusIn applies the In predicate on the "status" field.
func StatusIn(vs ...string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldStatus, vs...))
}
// StatusNotIn applies the NotIn predicate on the "status" field.
func StatusNotIn(vs ...string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldStatus, vs...))
}
// StatusGT applies the GT predicate on the "status" field.
func StatusGT(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldStatus, v))
}
// StatusGTE applies the GTE predicate on the "status" field.
func StatusGTE(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldStatus, v))
}
// StatusLT applies the LT predicate on the "status" field.
func StatusLT(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldStatus, v))
}
// StatusLTE applies the LTE predicate on the "status" field.
func StatusLTE(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldStatus, v))
}
// StatusContains applies the Contains predicate on the "status" field.
func StatusContains(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldContains(FieldStatus, v))
}
// StatusHasPrefix applies the HasPrefix predicate on the "status" field.
func StatusHasPrefix(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldHasPrefix(FieldStatus, v))
}
// StatusHasSuffix applies the HasSuffix predicate on the "status" field.
func StatusHasSuffix(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldHasSuffix(FieldStatus, v))
}
// StatusEqualFold applies the EqualFold predicate on the "status" field.
func StatusEqualFold(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEqualFold(FieldStatus, v))
}
// StatusContainsFold applies the ContainsFold predicate on the "status" field.
func StatusContainsFold(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldContainsFold(FieldStatus, v))
}
// CreatedByEQ applies the EQ predicate on the "created_by" field.
func CreatedByEQ(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedBy, v))
}
// CreatedByNEQ applies the NEQ predicate on the "created_by" field.
func CreatedByNEQ(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCreatedBy, v))
}
// CreatedByIn applies the In predicate on the "created_by" field.
func CreatedByIn(vs ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldCreatedBy, vs...))
}
// CreatedByNotIn applies the NotIn predicate on the "created_by" field.
func CreatedByNotIn(vs ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCreatedBy, vs...))
}
// CreatedByGT applies the GT predicate on the "created_by" field.
func CreatedByGT(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldCreatedBy, v))
}
// CreatedByGTE applies the GTE predicate on the "created_by" field.
func CreatedByGTE(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldCreatedBy, v))
}
// CreatedByLT applies the LT predicate on the "created_by" field.
func CreatedByLT(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldCreatedBy, v))
}
// CreatedByLTE applies the LTE predicate on the "created_by" field.
func CreatedByLTE(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldCreatedBy, v))
}
// DeletedRowsEQ applies the EQ predicate on the "deleted_rows" field.
func DeletedRowsEQ(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldDeletedRows, v))
}
// DeletedRowsNEQ applies the NEQ predicate on the "deleted_rows" field.
func DeletedRowsNEQ(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldDeletedRows, v))
}
// DeletedRowsIn applies the In predicate on the "deleted_rows" field.
func DeletedRowsIn(vs ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldDeletedRows, vs...))
}
// DeletedRowsNotIn applies the NotIn predicate on the "deleted_rows" field.
func DeletedRowsNotIn(vs ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldDeletedRows, vs...))
}
// DeletedRowsGT applies the GT predicate on the "deleted_rows" field.
func DeletedRowsGT(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldDeletedRows, v))
}
// DeletedRowsGTE applies the GTE predicate on the "deleted_rows" field.
func DeletedRowsGTE(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldDeletedRows, v))
}
// DeletedRowsLT applies the LT predicate on the "deleted_rows" field.
func DeletedRowsLT(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldDeletedRows, v))
}
// DeletedRowsLTE applies the LTE predicate on the "deleted_rows" field.
func DeletedRowsLTE(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldDeletedRows, v))
}
// ErrorMessageEQ applies the EQ predicate on the "error_message" field.
func ErrorMessageEQ(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldErrorMessage, v))
}
// ErrorMessageNEQ applies the NEQ predicate on the "error_message" field.
func ErrorMessageNEQ(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldErrorMessage, v))
}
// ErrorMessageIn applies the In predicate on the "error_message" field.
func ErrorMessageIn(vs ...string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldErrorMessage, vs...))
}
// ErrorMessageNotIn applies the NotIn predicate on the "error_message" field.
func ErrorMessageNotIn(vs ...string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldErrorMessage, vs...))
}
// ErrorMessageGT applies the GT predicate on the "error_message" field.
func ErrorMessageGT(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldErrorMessage, v))
}
// ErrorMessageGTE applies the GTE predicate on the "error_message" field.
func ErrorMessageGTE(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldErrorMessage, v))
}
// ErrorMessageLT applies the LT predicate on the "error_message" field.
func ErrorMessageLT(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldErrorMessage, v))
}
// ErrorMessageLTE applies the LTE predicate on the "error_message" field.
func ErrorMessageLTE(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldErrorMessage, v))
}
// ErrorMessageContains applies the Contains predicate on the "error_message" field.
func ErrorMessageContains(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldContains(FieldErrorMessage, v))
}
// ErrorMessageHasPrefix applies the HasPrefix predicate on the "error_message" field.
func ErrorMessageHasPrefix(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldHasPrefix(FieldErrorMessage, v))
}
// ErrorMessageHasSuffix applies the HasSuffix predicate on the "error_message" field.
func ErrorMessageHasSuffix(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldHasSuffix(FieldErrorMessage, v))
}
// ErrorMessageIsNil applies the IsNil predicate on the "error_message" field.
func ErrorMessageIsNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldErrorMessage))
}
// ErrorMessageNotNil applies the NotNil predicate on the "error_message" field.
func ErrorMessageNotNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldErrorMessage))
}
// ErrorMessageEqualFold applies the EqualFold predicate on the "error_message" field.
func ErrorMessageEqualFold(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEqualFold(FieldErrorMessage, v))
}
// ErrorMessageContainsFold applies the ContainsFold predicate on the "error_message" field.
func ErrorMessageContainsFold(v string) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldContainsFold(FieldErrorMessage, v))
}
// CanceledByEQ applies the EQ predicate on the "canceled_by" field.
func CanceledByEQ(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledBy, v))
}
// CanceledByNEQ applies the NEQ predicate on the "canceled_by" field.
func CanceledByNEQ(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCanceledBy, v))
}
// CanceledByIn applies the In predicate on the "canceled_by" field.
func CanceledByIn(vs ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldCanceledBy, vs...))
}
// CanceledByNotIn applies the NotIn predicate on the "canceled_by" field.
func CanceledByNotIn(vs ...int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCanceledBy, vs...))
}
// CanceledByGT applies the GT predicate on the "canceled_by" field.
func CanceledByGT(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldCanceledBy, v))
}
// CanceledByGTE applies the GTE predicate on the "canceled_by" field.
func CanceledByGTE(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldCanceledBy, v))
}
// CanceledByLT applies the LT predicate on the "canceled_by" field.
func CanceledByLT(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldCanceledBy, v))
}
// CanceledByLTE applies the LTE predicate on the "canceled_by" field.
func CanceledByLTE(v int64) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldCanceledBy, v))
}
// CanceledByIsNil applies the IsNil predicate on the "canceled_by" field.
func CanceledByIsNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldCanceledBy))
}
// CanceledByNotNil applies the NotNil predicate on the "canceled_by" field.
func CanceledByNotNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldCanceledBy))
}
// CanceledAtEQ applies the EQ predicate on the "canceled_at" field.
func CanceledAtEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledAt, v))
}
// CanceledAtNEQ applies the NEQ predicate on the "canceled_at" field.
func CanceledAtNEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCanceledAt, v))
}
// CanceledAtIn applies the In predicate on the "canceled_at" field.
func CanceledAtIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldCanceledAt, vs...))
}
// CanceledAtNotIn applies the NotIn predicate on the "canceled_at" field.
func CanceledAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCanceledAt, vs...))
}
// CanceledAtGT applies the GT predicate on the "canceled_at" field.
func CanceledAtGT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldCanceledAt, v))
}
// CanceledAtGTE applies the GTE predicate on the "canceled_at" field.
func CanceledAtGTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldCanceledAt, v))
}
// CanceledAtLT applies the LT predicate on the "canceled_at" field.
func CanceledAtLT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldCanceledAt, v))
}
// CanceledAtLTE applies the LTE predicate on the "canceled_at" field.
func CanceledAtLTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldCanceledAt, v))
}
// CanceledAtIsNil applies the IsNil predicate on the "canceled_at" field.
func CanceledAtIsNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldCanceledAt))
}
// CanceledAtNotNil applies the NotNil predicate on the "canceled_at" field.
func CanceledAtNotNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldCanceledAt))
}
// StartedAtEQ applies the EQ predicate on the "started_at" field.
func StartedAtEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldStartedAt, v))
}
// StartedAtNEQ applies the NEQ predicate on the "started_at" field.
func StartedAtNEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldStartedAt, v))
}
// StartedAtIn applies the In predicate on the "started_at" field.
func StartedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldStartedAt, vs...))
}
// StartedAtNotIn applies the NotIn predicate on the "started_at" field.
func StartedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldStartedAt, vs...))
}
// StartedAtGT applies the GT predicate on the "started_at" field.
func StartedAtGT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldStartedAt, v))
}
// StartedAtGTE applies the GTE predicate on the "started_at" field.
func StartedAtGTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldStartedAt, v))
}
// StartedAtLT applies the LT predicate on the "started_at" field.
func StartedAtLT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldStartedAt, v))
}
// StartedAtLTE applies the LTE predicate on the "started_at" field.
func StartedAtLTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldStartedAt, v))
}
// StartedAtIsNil applies the IsNil predicate on the "started_at" field.
func StartedAtIsNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldStartedAt))
}
// StartedAtNotNil applies the NotNil predicate on the "started_at" field.
func StartedAtNotNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldStartedAt))
}
// FinishedAtEQ applies the EQ predicate on the "finished_at" field.
func FinishedAtEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldEQ(FieldFinishedAt, v))
}
// FinishedAtNEQ applies the NEQ predicate on the "finished_at" field.
func FinishedAtNEQ(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNEQ(FieldFinishedAt, v))
}
// FinishedAtIn applies the In predicate on the "finished_at" field.
func FinishedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIn(FieldFinishedAt, vs...))
}
// FinishedAtNotIn applies the NotIn predicate on the "finished_at" field.
func FinishedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotIn(FieldFinishedAt, vs...))
}
// FinishedAtGT applies the GT predicate on the "finished_at" field.
func FinishedAtGT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGT(FieldFinishedAt, v))
}
// FinishedAtGTE applies the GTE predicate on the "finished_at" field.
func FinishedAtGTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldGTE(FieldFinishedAt, v))
}
// FinishedAtLT applies the LT predicate on the "finished_at" field.
func FinishedAtLT(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLT(FieldFinishedAt, v))
}
// FinishedAtLTE applies the LTE predicate on the "finished_at" field.
func FinishedAtLTE(v time.Time) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldLTE(FieldFinishedAt, v))
}
// FinishedAtIsNil applies the IsNil predicate on the "finished_at" field.
func FinishedAtIsNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldIsNull(FieldFinishedAt))
}
// FinishedAtNotNil applies the NotNil predicate on the "finished_at" field.
func FinishedAtNotNil() predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.FieldNotNull(FieldFinishedAt))
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.UsageCleanupTask) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.UsageCleanupTask) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.UsageCleanupTask) predicate.UsageCleanupTask {
return predicate.UsageCleanupTask(sql.NotPredicates(p))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)
// UsageCleanupTaskDelete is the builder for deleting a UsageCleanupTask entity.
type UsageCleanupTaskDelete struct {
config
hooks []Hook
mutation *UsageCleanupTaskMutation
}
// Where appends a list predicates to the UsageCleanupTaskDelete builder.
func (_d *UsageCleanupTaskDelete) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *UsageCleanupTaskDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UsageCleanupTaskDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *UsageCleanupTaskDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(usagecleanuptask.Table, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// UsageCleanupTaskDeleteOne is the builder for deleting a single UsageCleanupTask entity.
type UsageCleanupTaskDeleteOne struct {
_d *UsageCleanupTaskDelete
}
// Where appends a list predicates to the UsageCleanupTaskDelete builder.
func (_d *UsageCleanupTaskDeleteOne) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *UsageCleanupTaskDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{usagecleanuptask.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UsageCleanupTaskDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,564 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)
// UsageCleanupTaskQuery is the builder for querying UsageCleanupTask entities.
type UsageCleanupTaskQuery struct {
config
ctx *QueryContext
order []usagecleanuptask.OrderOption
inters []Interceptor
predicates []predicate.UsageCleanupTask
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the UsageCleanupTaskQuery builder.
func (_q *UsageCleanupTaskQuery) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *UsageCleanupTaskQuery) Limit(limit int) *UsageCleanupTaskQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *UsageCleanupTaskQuery) Offset(offset int) *UsageCleanupTaskQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *UsageCleanupTaskQuery) Unique(unique bool) *UsageCleanupTaskQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *UsageCleanupTaskQuery) Order(o ...usagecleanuptask.OrderOption) *UsageCleanupTaskQuery {
_q.order = append(_q.order, o...)
return _q
}
// First returns the first UsageCleanupTask entity from the query.
// Returns a *NotFoundError when no UsageCleanupTask was found.
func (_q *UsageCleanupTaskQuery) First(ctx context.Context) (*UsageCleanupTask, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{usagecleanuptask.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) FirstX(ctx context.Context) *UsageCleanupTask {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first UsageCleanupTask ID from the query.
// Returns a *NotFoundError when no UsageCleanupTask ID was found.
func (_q *UsageCleanupTaskQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{usagecleanuptask.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single UsageCleanupTask entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one UsageCleanupTask entity is found.
// Returns a *NotFoundError when no UsageCleanupTask entities are found.
func (_q *UsageCleanupTaskQuery) Only(ctx context.Context) (*UsageCleanupTask, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{usagecleanuptask.Label}
default:
return nil, &NotSingularError{usagecleanuptask.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) OnlyX(ctx context.Context) *UsageCleanupTask {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only UsageCleanupTask ID in the query.
// Returns a *NotSingularError when more than one UsageCleanupTask ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *UsageCleanupTaskQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{usagecleanuptask.Label}
default:
err = &NotSingularError{usagecleanuptask.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of UsageCleanupTasks.
func (_q *UsageCleanupTaskQuery) All(ctx context.Context) ([]*UsageCleanupTask, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*UsageCleanupTask, *UsageCleanupTaskQuery]()
return withInterceptors[[]*UsageCleanupTask](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) AllX(ctx context.Context) []*UsageCleanupTask {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of UsageCleanupTask IDs.
func (_q *UsageCleanupTaskQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(usagecleanuptask.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *UsageCleanupTaskQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*UsageCleanupTaskQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *UsageCleanupTaskQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *UsageCleanupTaskQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the UsageCleanupTaskQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *UsageCleanupTaskQuery) Clone() *UsageCleanupTaskQuery {
if _q == nil {
return nil
}
return &UsageCleanupTaskQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]usagecleanuptask.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.UsageCleanupTask{}, _q.predicates...),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.UsageCleanupTask.Query().
// GroupBy(usagecleanuptask.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *UsageCleanupTaskQuery) GroupBy(field string, fields ...string) *UsageCleanupTaskGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &UsageCleanupTaskGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = usagecleanuptask.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.UsageCleanupTask.Query().
// Select(usagecleanuptask.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *UsageCleanupTaskQuery) Select(fields ...string) *UsageCleanupTaskSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &UsageCleanupTaskSelect{UsageCleanupTaskQuery: _q}
sbuild.label = usagecleanuptask.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a UsageCleanupTaskSelect configured with the given aggregations.
func (_q *UsageCleanupTaskQuery) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *UsageCleanupTaskQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !usagecleanuptask.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *UsageCleanupTaskQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UsageCleanupTask, error) {
var (
nodes = []*UsageCleanupTask{}
_spec = _q.querySpec()
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*UsageCleanupTask).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &UsageCleanupTask{config: _q.config}
nodes = append(nodes, node)
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
return nodes, nil
}
func (_q *UsageCleanupTaskQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *UsageCleanupTaskQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, usagecleanuptask.FieldID)
for i := range fields {
if fields[i] != usagecleanuptask.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *UsageCleanupTaskQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(usagecleanuptask.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = usagecleanuptask.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UsageCleanupTaskQuery) ForUpdate(opts ...sql.LockOption) *UsageCleanupTaskQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UsageCleanupTaskQuery) ForShare(opts ...sql.LockOption) *UsageCleanupTaskQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UsageCleanupTaskGroupBy is the group-by builder for UsageCleanupTask entities.
type UsageCleanupTaskGroupBy struct {
selector
build *UsageCleanupTaskQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *UsageCleanupTaskGroupBy) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *UsageCleanupTaskGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UsageCleanupTaskQuery, *UsageCleanupTaskGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *UsageCleanupTaskGroupBy) sqlScan(ctx context.Context, root *UsageCleanupTaskQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// UsageCleanupTaskSelect is the builder for selecting fields of UsageCleanupTask entities.
type UsageCleanupTaskSelect struct {
*UsageCleanupTaskQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *UsageCleanupTaskSelect) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *UsageCleanupTaskSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UsageCleanupTaskQuery, *UsageCleanupTaskSelect](ctx, _s.UsageCleanupTaskQuery, _s, _s.inters, v)
}
func (_s *UsageCleanupTaskSelect) sqlScan(ctx context.Context, root *UsageCleanupTaskQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,702 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)
// UsageCleanupTaskUpdate is the builder for updating UsageCleanupTask entities.
type UsageCleanupTaskUpdate struct {
config
hooks []Hook
mutation *UsageCleanupTaskMutation
}
// Where appends a list predicates to the UsageCleanupTaskUpdate builder.
func (_u *UsageCleanupTaskUpdate) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UsageCleanupTaskUpdate) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetStatus sets the "status" field.
func (_u *UsageCleanupTaskUpdate) SetStatus(v string) *UsageCleanupTaskUpdate {
_u.mutation.SetStatus(v)
return _u
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableStatus(v *string) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetStatus(*v)
}
return _u
}
// SetFilters sets the "filters" field.
func (_u *UsageCleanupTaskUpdate) SetFilters(v json.RawMessage) *UsageCleanupTaskUpdate {
_u.mutation.SetFilters(v)
return _u
}
// AppendFilters appends value to the "filters" field.
func (_u *UsageCleanupTaskUpdate) AppendFilters(v json.RawMessage) *UsageCleanupTaskUpdate {
_u.mutation.AppendFilters(v)
return _u
}
// SetCreatedBy sets the "created_by" field.
func (_u *UsageCleanupTaskUpdate) SetCreatedBy(v int64) *UsageCleanupTaskUpdate {
_u.mutation.ResetCreatedBy()
_u.mutation.SetCreatedBy(v)
return _u
}
// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableCreatedBy(v *int64) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetCreatedBy(*v)
}
return _u
}
// AddCreatedBy adds value to the "created_by" field.
func (_u *UsageCleanupTaskUpdate) AddCreatedBy(v int64) *UsageCleanupTaskUpdate {
_u.mutation.AddCreatedBy(v)
return _u
}
// SetDeletedRows sets the "deleted_rows" field.
func (_u *UsageCleanupTaskUpdate) SetDeletedRows(v int64) *UsageCleanupTaskUpdate {
_u.mutation.ResetDeletedRows()
_u.mutation.SetDeletedRows(v)
return _u
}
// SetNillableDeletedRows sets the "deleted_rows" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableDeletedRows(v *int64) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetDeletedRows(*v)
}
return _u
}
// AddDeletedRows adds value to the "deleted_rows" field.
func (_u *UsageCleanupTaskUpdate) AddDeletedRows(v int64) *UsageCleanupTaskUpdate {
_u.mutation.AddDeletedRows(v)
return _u
}
// SetErrorMessage sets the "error_message" field.
func (_u *UsageCleanupTaskUpdate) SetErrorMessage(v string) *UsageCleanupTaskUpdate {
_u.mutation.SetErrorMessage(v)
return _u
}
// SetNillableErrorMessage sets the "error_message" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableErrorMessage(v *string) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetErrorMessage(*v)
}
return _u
}
// ClearErrorMessage clears the value of the "error_message" field.
func (_u *UsageCleanupTaskUpdate) ClearErrorMessage() *UsageCleanupTaskUpdate {
_u.mutation.ClearErrorMessage()
return _u
}
// SetCanceledBy sets the "canceled_by" field.
func (_u *UsageCleanupTaskUpdate) SetCanceledBy(v int64) *UsageCleanupTaskUpdate {
_u.mutation.ResetCanceledBy()
_u.mutation.SetCanceledBy(v)
return _u
}
// SetNillableCanceledBy sets the "canceled_by" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableCanceledBy(v *int64) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetCanceledBy(*v)
}
return _u
}
// AddCanceledBy adds value to the "canceled_by" field.
func (_u *UsageCleanupTaskUpdate) AddCanceledBy(v int64) *UsageCleanupTaskUpdate {
_u.mutation.AddCanceledBy(v)
return _u
}
// ClearCanceledBy clears the value of the "canceled_by" field.
func (_u *UsageCleanupTaskUpdate) ClearCanceledBy() *UsageCleanupTaskUpdate {
_u.mutation.ClearCanceledBy()
return _u
}
// SetCanceledAt sets the "canceled_at" field.
func (_u *UsageCleanupTaskUpdate) SetCanceledAt(v time.Time) *UsageCleanupTaskUpdate {
_u.mutation.SetCanceledAt(v)
return _u
}
// SetNillableCanceledAt sets the "canceled_at" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableCanceledAt(v *time.Time) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetCanceledAt(*v)
}
return _u
}
// ClearCanceledAt clears the value of the "canceled_at" field.
func (_u *UsageCleanupTaskUpdate) ClearCanceledAt() *UsageCleanupTaskUpdate {
_u.mutation.ClearCanceledAt()
return _u
}
// SetStartedAt sets the "started_at" field.
func (_u *UsageCleanupTaskUpdate) SetStartedAt(v time.Time) *UsageCleanupTaskUpdate {
_u.mutation.SetStartedAt(v)
return _u
}
// SetNillableStartedAt sets the "started_at" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableStartedAt(v *time.Time) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetStartedAt(*v)
}
return _u
}
// ClearStartedAt clears the value of the "started_at" field.
func (_u *UsageCleanupTaskUpdate) ClearStartedAt() *UsageCleanupTaskUpdate {
_u.mutation.ClearStartedAt()
return _u
}
// SetFinishedAt sets the "finished_at" field.
func (_u *UsageCleanupTaskUpdate) SetFinishedAt(v time.Time) *UsageCleanupTaskUpdate {
_u.mutation.SetFinishedAt(v)
return _u
}
// SetNillableFinishedAt sets the "finished_at" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdate) SetNillableFinishedAt(v *time.Time) *UsageCleanupTaskUpdate {
if v != nil {
_u.SetFinishedAt(*v)
}
return _u
}
// ClearFinishedAt clears the value of the "finished_at" field.
func (_u *UsageCleanupTaskUpdate) ClearFinishedAt() *UsageCleanupTaskUpdate {
_u.mutation.ClearFinishedAt()
return _u
}
// Mutation returns the UsageCleanupTaskMutation object of the builder.
func (_u *UsageCleanupTaskUpdate) Mutation() *UsageCleanupTaskMutation {
return _u.mutation
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UsageCleanupTaskUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UsageCleanupTaskUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *UsageCleanupTaskUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UsageCleanupTaskUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UsageCleanupTaskUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := usagecleanuptask.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *UsageCleanupTaskUpdate) check() error {
if v, ok := _u.mutation.Status(); ok {
if err := usagecleanuptask.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UsageCleanupTask.status": %w`, err)}
}
}
return nil
}
func (_u *UsageCleanupTaskUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usagecleanuptask.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(usagecleanuptask.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.Filters(); ok {
_spec.SetField(usagecleanuptask.FieldFilters, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedFilters(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, usagecleanuptask.FieldFilters, value)
})
}
if value, ok := _u.mutation.CreatedBy(); ok {
_spec.SetField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedCreatedBy(); ok {
_spec.AddField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.DeletedRows(); ok {
_spec.SetField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedDeletedRows(); ok {
_spec.AddField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
}
if value, ok := _u.mutation.ErrorMessage(); ok {
_spec.SetField(usagecleanuptask.FieldErrorMessage, field.TypeString, value)
}
if _u.mutation.ErrorMessageCleared() {
_spec.ClearField(usagecleanuptask.FieldErrorMessage, field.TypeString)
}
if value, ok := _u.mutation.CanceledBy(); ok {
_spec.SetField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedCanceledBy(); ok {
_spec.AddField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
}
if _u.mutation.CanceledByCleared() {
_spec.ClearField(usagecleanuptask.FieldCanceledBy, field.TypeInt64)
}
if value, ok := _u.mutation.CanceledAt(); ok {
_spec.SetField(usagecleanuptask.FieldCanceledAt, field.TypeTime, value)
}
if _u.mutation.CanceledAtCleared() {
_spec.ClearField(usagecleanuptask.FieldCanceledAt, field.TypeTime)
}
if value, ok := _u.mutation.StartedAt(); ok {
_spec.SetField(usagecleanuptask.FieldStartedAt, field.TypeTime, value)
}
if _u.mutation.StartedAtCleared() {
_spec.ClearField(usagecleanuptask.FieldStartedAt, field.TypeTime)
}
if value, ok := _u.mutation.FinishedAt(); ok {
_spec.SetField(usagecleanuptask.FieldFinishedAt, field.TypeTime, value)
}
if _u.mutation.FinishedAtCleared() {
_spec.ClearField(usagecleanuptask.FieldFinishedAt, field.TypeTime)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{usagecleanuptask.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// UsageCleanupTaskUpdateOne is the builder for updating a single UsageCleanupTask entity.
type UsageCleanupTaskUpdateOne struct {
config
fields []string
hooks []Hook
mutation *UsageCleanupTaskMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UsageCleanupTaskUpdateOne) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetStatus sets the "status" field.
func (_u *UsageCleanupTaskUpdateOne) SetStatus(v string) *UsageCleanupTaskUpdateOne {
_u.mutation.SetStatus(v)
return _u
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableStatus(v *string) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetStatus(*v)
}
return _u
}
// SetFilters sets the "filters" field.
func (_u *UsageCleanupTaskUpdateOne) SetFilters(v json.RawMessage) *UsageCleanupTaskUpdateOne {
_u.mutation.SetFilters(v)
return _u
}
// AppendFilters appends value to the "filters" field.
func (_u *UsageCleanupTaskUpdateOne) AppendFilters(v json.RawMessage) *UsageCleanupTaskUpdateOne {
_u.mutation.AppendFilters(v)
return _u
}
// SetCreatedBy sets the "created_by" field.
func (_u *UsageCleanupTaskUpdateOne) SetCreatedBy(v int64) *UsageCleanupTaskUpdateOne {
_u.mutation.ResetCreatedBy()
_u.mutation.SetCreatedBy(v)
return _u
}
// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableCreatedBy(v *int64) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetCreatedBy(*v)
}
return _u
}
// AddCreatedBy adds value to the "created_by" field.
func (_u *UsageCleanupTaskUpdateOne) AddCreatedBy(v int64) *UsageCleanupTaskUpdateOne {
_u.mutation.AddCreatedBy(v)
return _u
}
// SetDeletedRows sets the "deleted_rows" field.
func (_u *UsageCleanupTaskUpdateOne) SetDeletedRows(v int64) *UsageCleanupTaskUpdateOne {
_u.mutation.ResetDeletedRows()
_u.mutation.SetDeletedRows(v)
return _u
}
// SetNillableDeletedRows sets the "deleted_rows" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableDeletedRows(v *int64) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetDeletedRows(*v)
}
return _u
}
// AddDeletedRows adds value to the "deleted_rows" field.
func (_u *UsageCleanupTaskUpdateOne) AddDeletedRows(v int64) *UsageCleanupTaskUpdateOne {
_u.mutation.AddDeletedRows(v)
return _u
}
// SetErrorMessage sets the "error_message" field.
func (_u *UsageCleanupTaskUpdateOne) SetErrorMessage(v string) *UsageCleanupTaskUpdateOne {
_u.mutation.SetErrorMessage(v)
return _u
}
// SetNillableErrorMessage sets the "error_message" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableErrorMessage(v *string) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetErrorMessage(*v)
}
return _u
}
// ClearErrorMessage clears the value of the "error_message" field.
func (_u *UsageCleanupTaskUpdateOne) ClearErrorMessage() *UsageCleanupTaskUpdateOne {
_u.mutation.ClearErrorMessage()
return _u
}
// SetCanceledBy sets the "canceled_by" field.
func (_u *UsageCleanupTaskUpdateOne) SetCanceledBy(v int64) *UsageCleanupTaskUpdateOne {
_u.mutation.ResetCanceledBy()
_u.mutation.SetCanceledBy(v)
return _u
}
// SetNillableCanceledBy sets the "canceled_by" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableCanceledBy(v *int64) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetCanceledBy(*v)
}
return _u
}
// AddCanceledBy adds value to the "canceled_by" field.
func (_u *UsageCleanupTaskUpdateOne) AddCanceledBy(v int64) *UsageCleanupTaskUpdateOne {
_u.mutation.AddCanceledBy(v)
return _u
}
// ClearCanceledBy clears the value of the "canceled_by" field.
func (_u *UsageCleanupTaskUpdateOne) ClearCanceledBy() *UsageCleanupTaskUpdateOne {
_u.mutation.ClearCanceledBy()
return _u
}
// SetCanceledAt sets the "canceled_at" field.
func (_u *UsageCleanupTaskUpdateOne) SetCanceledAt(v time.Time) *UsageCleanupTaskUpdateOne {
_u.mutation.SetCanceledAt(v)
return _u
}
// SetNillableCanceledAt sets the "canceled_at" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableCanceledAt(v *time.Time) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetCanceledAt(*v)
}
return _u
}
// ClearCanceledAt clears the value of the "canceled_at" field.
func (_u *UsageCleanupTaskUpdateOne) ClearCanceledAt() *UsageCleanupTaskUpdateOne {
_u.mutation.ClearCanceledAt()
return _u
}
// SetStartedAt sets the "started_at" field.
func (_u *UsageCleanupTaskUpdateOne) SetStartedAt(v time.Time) *UsageCleanupTaskUpdateOne {
_u.mutation.SetStartedAt(v)
return _u
}
// SetNillableStartedAt sets the "started_at" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableStartedAt(v *time.Time) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetStartedAt(*v)
}
return _u
}
// ClearStartedAt clears the value of the "started_at" field.
func (_u *UsageCleanupTaskUpdateOne) ClearStartedAt() *UsageCleanupTaskUpdateOne {
_u.mutation.ClearStartedAt()
return _u
}
// SetFinishedAt sets the "finished_at" field.
func (_u *UsageCleanupTaskUpdateOne) SetFinishedAt(v time.Time) *UsageCleanupTaskUpdateOne {
_u.mutation.SetFinishedAt(v)
return _u
}
// SetNillableFinishedAt sets the "finished_at" field if the given value is not nil.
func (_u *UsageCleanupTaskUpdateOne) SetNillableFinishedAt(v *time.Time) *UsageCleanupTaskUpdateOne {
if v != nil {
_u.SetFinishedAt(*v)
}
return _u
}
// ClearFinishedAt clears the value of the "finished_at" field.
func (_u *UsageCleanupTaskUpdateOne) ClearFinishedAt() *UsageCleanupTaskUpdateOne {
_u.mutation.ClearFinishedAt()
return _u
}
// Mutation returns the UsageCleanupTaskMutation object of the builder.
func (_u *UsageCleanupTaskUpdateOne) Mutation() *UsageCleanupTaskMutation {
return _u.mutation
}
// Where appends a list predicates to the UsageCleanupTaskUpdate builder.
func (_u *UsageCleanupTaskUpdateOne) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *UsageCleanupTaskUpdateOne) Select(field string, fields ...string) *UsageCleanupTaskUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated UsageCleanupTask entity.
func (_u *UsageCleanupTaskUpdateOne) Save(ctx context.Context) (*UsageCleanupTask, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UsageCleanupTaskUpdateOne) SaveX(ctx context.Context) *UsageCleanupTask {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *UsageCleanupTaskUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UsageCleanupTaskUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UsageCleanupTaskUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := usagecleanuptask.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *UsageCleanupTaskUpdateOne) check() error {
if v, ok := _u.mutation.Status(); ok {
if err := usagecleanuptask.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UsageCleanupTask.status": %w`, err)}
}
}
return nil
}
func (_u *UsageCleanupTaskUpdateOne) sqlSave(ctx context.Context) (_node *UsageCleanupTask, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UsageCleanupTask.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, usagecleanuptask.FieldID)
for _, f := range fields {
if !usagecleanuptask.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != usagecleanuptask.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(usagecleanuptask.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(usagecleanuptask.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.Filters(); ok {
_spec.SetField(usagecleanuptask.FieldFilters, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedFilters(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, usagecleanuptask.FieldFilters, value)
})
}
if value, ok := _u.mutation.CreatedBy(); ok {
_spec.SetField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedCreatedBy(); ok {
_spec.AddField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.DeletedRows(); ok {
_spec.SetField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedDeletedRows(); ok {
_spec.AddField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
}
if value, ok := _u.mutation.ErrorMessage(); ok {
_spec.SetField(usagecleanuptask.FieldErrorMessage, field.TypeString, value)
}
if _u.mutation.ErrorMessageCleared() {
_spec.ClearField(usagecleanuptask.FieldErrorMessage, field.TypeString)
}
if value, ok := _u.mutation.CanceledBy(); ok {
_spec.SetField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedCanceledBy(); ok {
_spec.AddField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
}
if _u.mutation.CanceledByCleared() {
_spec.ClearField(usagecleanuptask.FieldCanceledBy, field.TypeInt64)
}
if value, ok := _u.mutation.CanceledAt(); ok {
_spec.SetField(usagecleanuptask.FieldCanceledAt, field.TypeTime, value)
}
if _u.mutation.CanceledAtCleared() {
_spec.ClearField(usagecleanuptask.FieldCanceledAt, field.TypeTime)
}
if value, ok := _u.mutation.StartedAt(); ok {
_spec.SetField(usagecleanuptask.FieldStartedAt, field.TypeTime, value)
}
if _u.mutation.StartedAtCleared() {
_spec.ClearField(usagecleanuptask.FieldStartedAt, field.TypeTime)
}
if value, ok := _u.mutation.FinishedAt(); ok {
_spec.SetField(usagecleanuptask.FieldFinishedAt, field.TypeTime, value)
}
if _u.mutation.FinishedAtCleared() {
_spec.ClearField(usagecleanuptask.FieldFinishedAt, field.TypeTime)
}
_node = &UsageCleanupTask{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{usagecleanuptask.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -62,6 +62,8 @@ type UsageLog struct {
ActualCost float64 `json:"actual_cost,omitempty"`
// RateMultiplier holds the value of the "rate_multiplier" field.
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
// AccountRateMultiplier holds the value of the "account_rate_multiplier" field.
AccountRateMultiplier *float64 `json:"account_rate_multiplier,omitempty"`
// BillingType holds the value of the "billing_type" field.
BillingType int8 `json:"billing_type,omitempty"`
// Stream holds the value of the "stream" field.
@@ -165,7 +167,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case usagelog.FieldStream:
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier:
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
@@ -316,6 +318,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.RateMultiplier = value.Float64
}
case usagelog.FieldAccountRateMultiplier:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field account_rate_multiplier", values[i])
} else if value.Valid {
_m.AccountRateMultiplier = new(float64)
*_m.AccountRateMultiplier = value.Float64
}
case usagelog.FieldBillingType:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field billing_type", values[i])
@@ -500,6 +509,11 @@ func (_m *UsageLog) String() string {
builder.WriteString("rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
builder.WriteString(", ")
if v := _m.AccountRateMultiplier; v != nil {
builder.WriteString("account_rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("billing_type=")
builder.WriteString(fmt.Sprintf("%v", _m.BillingType))
builder.WriteString(", ")

View File

@@ -54,6 +54,8 @@ const (
FieldActualCost = "actual_cost"
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
FieldRateMultiplier = "rate_multiplier"
// FieldAccountRateMultiplier holds the string denoting the account_rate_multiplier field in the database.
FieldAccountRateMultiplier = "account_rate_multiplier"
// FieldBillingType holds the string denoting the billing_type field in the database.
FieldBillingType = "billing_type"
// FieldStream holds the string denoting the stream field in the database.
@@ -144,6 +146,7 @@ var Columns = []string{
FieldTotalCost,
FieldActualCost,
FieldRateMultiplier,
FieldAccountRateMultiplier,
FieldBillingType,
FieldStream,
FieldDurationMs,
@@ -320,6 +323,11 @@ func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
}
// ByAccountRateMultiplier orders the results by the account_rate_multiplier field.
func ByAccountRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAccountRateMultiplier, opts...).ToFunc()
}
// ByBillingType orders the results by the billing_type field.
func ByBillingType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBillingType, opts...).ToFunc()

View File

@@ -155,6 +155,11 @@ func RateMultiplier(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v))
}
// AccountRateMultiplier applies equality check predicate on the "account_rate_multiplier" field. It's identical to AccountRateMultiplierEQ.
func AccountRateMultiplier(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
}
// BillingType applies equality check predicate on the "billing_type" field. It's identical to BillingTypeEQ.
func BillingType(v int8) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
@@ -970,6 +975,56 @@ func RateMultiplierLTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldRateMultiplier, v))
}
// AccountRateMultiplierEQ applies the EQ predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierEQ(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierNEQ applies the NEQ predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierNEQ(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierIn applies the In predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierIn(vs ...float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldAccountRateMultiplier, vs...))
}
// AccountRateMultiplierNotIn applies the NotIn predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierNotIn(vs ...float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldAccountRateMultiplier, vs...))
}
// AccountRateMultiplierGT applies the GT predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierGT(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierGTE applies the GTE predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierGTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierLT applies the LT predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierLT(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierLTE applies the LTE predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierLTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldAccountRateMultiplier, v))
}
// AccountRateMultiplierIsNil applies the IsNil predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldAccountRateMultiplier))
}
// AccountRateMultiplierNotNil applies the NotNil predicate on the "account_rate_multiplier" field.
func AccountRateMultiplierNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldAccountRateMultiplier))
}
// BillingTypeEQ applies the EQ predicate on the "billing_type" field.
func BillingTypeEQ(v int8) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))

View File

@@ -267,6 +267,20 @@ func (_c *UsageLogCreate) SetNillableRateMultiplier(v *float64) *UsageLogCreate
return _c
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (_c *UsageLogCreate) SetAccountRateMultiplier(v float64) *UsageLogCreate {
_c.mutation.SetAccountRateMultiplier(v)
return _c
}
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableAccountRateMultiplier(v *float64) *UsageLogCreate {
if v != nil {
_c.SetAccountRateMultiplier(*v)
}
return _c
}
// SetBillingType sets the "billing_type" field.
func (_c *UsageLogCreate) SetBillingType(v int8) *UsageLogCreate {
_c.mutation.SetBillingType(v)
@@ -712,6 +726,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
_node.RateMultiplier = value
}
if value, ok := _c.mutation.AccountRateMultiplier(); ok {
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
_node.AccountRateMultiplier = &value
}
if value, ok := _c.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
_node.BillingType = value
@@ -1215,6 +1233,30 @@ func (u *UsageLogUpsert) AddRateMultiplier(v float64) *UsageLogUpsert {
return u
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (u *UsageLogUpsert) SetAccountRateMultiplier(v float64) *UsageLogUpsert {
u.Set(usagelog.FieldAccountRateMultiplier, v)
return u
}
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateAccountRateMultiplier() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldAccountRateMultiplier)
return u
}
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
func (u *UsageLogUpsert) AddAccountRateMultiplier(v float64) *UsageLogUpsert {
u.Add(usagelog.FieldAccountRateMultiplier, v)
return u
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (u *UsageLogUpsert) ClearAccountRateMultiplier() *UsageLogUpsert {
u.SetNull(usagelog.FieldAccountRateMultiplier)
return u
}
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsert) SetBillingType(v int8) *UsageLogUpsert {
u.Set(usagelog.FieldBillingType, v)
@@ -1795,6 +1837,34 @@ func (u *UsageLogUpsertOne) UpdateRateMultiplier() *UsageLogUpsertOne {
})
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (u *UsageLogUpsertOne) SetAccountRateMultiplier(v float64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetAccountRateMultiplier(v)
})
}
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
func (u *UsageLogUpsertOne) AddAccountRateMultiplier(v float64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.AddAccountRateMultiplier(v)
})
}
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateAccountRateMultiplier() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateAccountRateMultiplier()
})
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (u *UsageLogUpsertOne) ClearAccountRateMultiplier() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearAccountRateMultiplier()
})
}
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsertOne) SetBillingType(v int8) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2566,6 +2636,34 @@ func (u *UsageLogUpsertBulk) UpdateRateMultiplier() *UsageLogUpsertBulk {
})
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (u *UsageLogUpsertBulk) SetAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetAccountRateMultiplier(v)
})
}
// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
func (u *UsageLogUpsertBulk) AddAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.AddAccountRateMultiplier(v)
})
}
// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateAccountRateMultiplier() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateAccountRateMultiplier()
})
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (u *UsageLogUpsertBulk) ClearAccountRateMultiplier() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearAccountRateMultiplier()
})
}
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsertBulk) SetBillingType(v int8) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {

View File

@@ -415,6 +415,33 @@ func (_u *UsageLogUpdate) AddRateMultiplier(v float64) *UsageLogUpdate {
return _u
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (_u *UsageLogUpdate) SetAccountRateMultiplier(v float64) *UsageLogUpdate {
_u.mutation.ResetAccountRateMultiplier()
_u.mutation.SetAccountRateMultiplier(v)
return _u
}
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdate {
if v != nil {
_u.SetAccountRateMultiplier(*v)
}
return _u
}
// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
func (_u *UsageLogUpdate) AddAccountRateMultiplier(v float64) *UsageLogUpdate {
_u.mutation.AddAccountRateMultiplier(v)
return _u
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (_u *UsageLogUpdate) ClearAccountRateMultiplier() *UsageLogUpdate {
_u.mutation.ClearAccountRateMultiplier()
return _u
}
// SetBillingType sets the "billing_type" field.
func (_u *UsageLogUpdate) SetBillingType(v int8) *UsageLogUpdate {
_u.mutation.ResetBillingType()
@@ -807,6 +834,15 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AccountRateMultiplier(); ok {
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if _u.mutation.AccountRateMultiplierCleared() {
_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
}
if value, ok := _u.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
}
@@ -1406,6 +1442,33 @@ func (_u *UsageLogUpdateOne) AddRateMultiplier(v float64) *UsageLogUpdateOne {
return _u
}
// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
func (_u *UsageLogUpdateOne) SetAccountRateMultiplier(v float64) *UsageLogUpdateOne {
_u.mutation.ResetAccountRateMultiplier()
_u.mutation.SetAccountRateMultiplier(v)
return _u
}
// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdateOne {
if v != nil {
_u.SetAccountRateMultiplier(*v)
}
return _u
}
// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
func (_u *UsageLogUpdateOne) AddAccountRateMultiplier(v float64) *UsageLogUpdateOne {
_u.mutation.AddAccountRateMultiplier(v)
return _u
}
// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
func (_u *UsageLogUpdateOne) ClearAccountRateMultiplier() *UsageLogUpdateOne {
_u.mutation.ClearAccountRateMultiplier()
return _u
}
// SetBillingType sets the "billing_type" field.
func (_u *UsageLogUpdateOne) SetBillingType(v int8) *UsageLogUpdateOne {
_u.mutation.ResetBillingType()
@@ -1828,6 +1891,15 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AccountRateMultiplier(); ok {
_spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
_spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
}
if _u.mutation.AccountRateMultiplierCleared() {
_spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
}
if value, ok := _u.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
}

View File

@@ -39,6 +39,12 @@ type User struct {
Username string `json:"username,omitempty"`
// Notes holds the value of the "notes" field.
Notes string `json:"notes,omitempty"`
// TotpSecretEncrypted holds the value of the "totp_secret_encrypted" field.
TotpSecretEncrypted *string `json:"totp_secret_encrypted,omitempty"`
// TotpEnabled holds the value of the "totp_enabled" field.
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -156,13 +162,15 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case user.FieldTotpEnabled:
values[i] = new(sql.NullBool)
case user.FieldBalance:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes:
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -252,6 +260,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Notes = value.String
}
case user.FieldTotpSecretEncrypted:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field totp_secret_encrypted", values[i])
} else if value.Valid {
_m.TotpSecretEncrypted = new(string)
*_m.TotpSecretEncrypted = value.String
}
case user.FieldTotpEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field totp_enabled", values[i])
} else if value.Valid {
_m.TotpEnabled = value.Bool
}
case user.FieldTotpEnabledAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field totp_enabled_at", values[i])
} else if value.Valid {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -367,6 +395,19 @@ func (_m *User) String() string {
builder.WriteString(", ")
builder.WriteString("notes=")
builder.WriteString(_m.Notes)
builder.WriteString(", ")
if v := _m.TotpSecretEncrypted; v != nil {
builder.WriteString("totp_secret_encrypted=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("totp_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.TotpEnabled))
builder.WriteString(", ")
if v := _m.TotpEnabledAt; v != nil {
builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteByte(')')
return builder.String()
}

View File

@@ -37,6 +37,12 @@ const (
FieldUsername = "username"
// FieldNotes holds the string denoting the notes field in the database.
FieldNotes = "notes"
// FieldTotpSecretEncrypted holds the string denoting the totp_secret_encrypted field in the database.
FieldTotpSecretEncrypted = "totp_secret_encrypted"
// FieldTotpEnabled holds the string denoting the totp_enabled field in the database.
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -134,6 +140,9 @@ var Columns = []string{
FieldStatus,
FieldUsername,
FieldNotes,
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
}
var (
@@ -188,6 +197,8 @@ var (
UsernameValidator func(string) error
// DefaultNotes holds the default value on creation for the "notes" field.
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
)
// OrderOption defines the ordering options for the User queries.
@@ -253,6 +264,21 @@ func ByNotes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldNotes, opts...).ToFunc()
}
// ByTotpSecretEncrypted orders the results by the totp_secret_encrypted field.
func ByTotpSecretEncrypted(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpSecretEncrypted, opts...).ToFunc()
}
// ByTotpEnabled orders the results by the totp_enabled field.
func ByTotpEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabled, opts...).ToFunc()
}
// ByTotpEnabledAt orders the results by the totp_enabled_at field.
func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -110,6 +110,21 @@ func Notes(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldNotes, v))
}
// TotpSecretEncrypted applies equality check predicate on the "totp_secret_encrypted" field. It's identical to TotpSecretEncryptedEQ.
func TotpSecretEncrypted(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
}
// TotpEnabled applies equality check predicate on the "totp_enabled" field. It's identical to TotpEnabledEQ.
func TotpEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
}
// TotpEnabledAt applies equality check predicate on the "totp_enabled_at" field. It's identical to TotpEnabledAtEQ.
func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -710,6 +725,141 @@ func NotesContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldNotes, v))
}
// TotpSecretEncryptedEQ applies the EQ predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedNEQ applies the NEQ predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedIn applies the In predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldTotpSecretEncrypted, vs...))
}
// TotpSecretEncryptedNotIn applies the NotIn predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldTotpSecretEncrypted, vs...))
}
// TotpSecretEncryptedGT applies the GT predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedGTE applies the GTE predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedLT applies the LT predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedLTE applies the LTE predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedContains applies the Contains predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedHasPrefix applies the HasPrefix predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedHasSuffix applies the HasSuffix predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedIsNil applies the IsNil predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldTotpSecretEncrypted))
}
// TotpSecretEncryptedNotNil applies the NotNil predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpSecretEncrypted))
}
// TotpSecretEncryptedEqualFold applies the EqualFold predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedContainsFold applies the ContainsFold predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldTotpSecretEncrypted, v))
}
// TotpEnabledEQ applies the EQ predicate on the "totp_enabled" field.
func TotpEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
}
// TotpEnabledNEQ applies the NEQ predicate on the "totp_enabled" field.
func TotpEnabledNEQ(v bool) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotpEnabled, v))
}
// TotpEnabledAtEQ applies the EQ predicate on the "totp_enabled_at" field.
func TotpEnabledAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// TotpEnabledAtNEQ applies the NEQ predicate on the "totp_enabled_at" field.
func TotpEnabledAtNEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotpEnabledAt, v))
}
// TotpEnabledAtIn applies the In predicate on the "totp_enabled_at" field.
func TotpEnabledAtIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldIn(FieldTotpEnabledAt, vs...))
}
// TotpEnabledAtNotIn applies the NotIn predicate on the "totp_enabled_at" field.
func TotpEnabledAtNotIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldNotIn(FieldTotpEnabledAt, vs...))
}
// TotpEnabledAtGT applies the GT predicate on the "totp_enabled_at" field.
func TotpEnabledAtGT(v time.Time) predicate.User {
return predicate.User(sql.FieldGT(FieldTotpEnabledAt, v))
}
// TotpEnabledAtGTE applies the GTE predicate on the "totp_enabled_at" field.
func TotpEnabledAtGTE(v time.Time) predicate.User {
return predicate.User(sql.FieldGTE(FieldTotpEnabledAt, v))
}
// TotpEnabledAtLT applies the LT predicate on the "totp_enabled_at" field.
func TotpEnabledAtLT(v time.Time) predicate.User {
return predicate.User(sql.FieldLT(FieldTotpEnabledAt, v))
}
// TotpEnabledAtLTE applies the LTE predicate on the "totp_enabled_at" field.
func TotpEnabledAtLTE(v time.Time) predicate.User {
return predicate.User(sql.FieldLTE(FieldTotpEnabledAt, v))
}
// TotpEnabledAtIsNil applies the IsNil predicate on the "totp_enabled_at" field.
func TotpEnabledAtIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldTotpEnabledAt))
}
// TotpEnabledAtNotNil applies the NotNil predicate on the "totp_enabled_at" field.
func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@@ -167,6 +167,48 @@ func (_c *UserCreate) SetNillableNotes(v *string) *UserCreate {
return _c
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (_c *UserCreate) SetTotpSecretEncrypted(v string) *UserCreate {
_c.mutation.SetTotpSecretEncrypted(v)
return _c
}
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotpSecretEncrypted(v *string) *UserCreate {
if v != nil {
_c.SetTotpSecretEncrypted(*v)
}
return _c
}
// SetTotpEnabled sets the "totp_enabled" field.
func (_c *UserCreate) SetTotpEnabled(v bool) *UserCreate {
_c.mutation.SetTotpEnabled(v)
return _c
}
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotpEnabled(v *bool) *UserCreate {
if v != nil {
_c.SetTotpEnabled(*v)
}
return _c
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (_c *UserCreate) SetTotpEnabledAt(v time.Time) *UserCreate {
_c.mutation.SetTotpEnabledAt(v)
return _c
}
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
if v != nil {
_c.SetTotpEnabledAt(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -362,6 +404,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultNotes
_c.mutation.SetNotes(v)
}
if _, ok := _c.mutation.TotpEnabled(); !ok {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
return nil
}
@@ -422,6 +468,9 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.Notes(); !ok {
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
}
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
return nil
}
@@ -493,6 +542,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldNotes, field.TypeString, value)
_node.Notes = value
}
if value, ok := _c.mutation.TotpSecretEncrypted(); ok {
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
_node.TotpSecretEncrypted = &value
}
if value, ok := _c.mutation.TotpEnabled(); ok {
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
_node.TotpEnabled = value
}
if value, ok := _c.mutation.TotpEnabledAt(); ok {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -815,6 +876,54 @@ func (u *UserUpsert) UpdateNotes() *UserUpsert {
return u
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (u *UserUpsert) SetTotpSecretEncrypted(v string) *UserUpsert {
u.Set(user.FieldTotpSecretEncrypted, v)
return u
}
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotpSecretEncrypted() *UserUpsert {
u.SetExcluded(user.FieldTotpSecretEncrypted)
return u
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (u *UserUpsert) ClearTotpSecretEncrypted() *UserUpsert {
u.SetNull(user.FieldTotpSecretEncrypted)
return u
}
// SetTotpEnabled sets the "totp_enabled" field.
func (u *UserUpsert) SetTotpEnabled(v bool) *UserUpsert {
u.Set(user.FieldTotpEnabled, v)
return u
}
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotpEnabled() *UserUpsert {
u.SetExcluded(user.FieldTotpEnabled)
return u
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (u *UserUpsert) SetTotpEnabledAt(v time.Time) *UserUpsert {
u.Set(user.FieldTotpEnabledAt, v)
return u
}
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotpEnabledAt() *UserUpsert {
u.SetExcluded(user.FieldTotpEnabledAt)
return u
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
u.SetNull(user.FieldTotpEnabledAt)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1021,6 +1130,62 @@ func (u *UserUpsertOne) UpdateNotes() *UserUpsertOne {
})
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (u *UserUpsertOne) SetTotpSecretEncrypted(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotpSecretEncrypted(v)
})
}
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotpSecretEncrypted() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpSecretEncrypted()
})
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (u *UserUpsertOne) ClearTotpSecretEncrypted() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearTotpSecretEncrypted()
})
}
// SetTotpEnabled sets the "totp_enabled" field.
func (u *UserUpsertOne) SetTotpEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabled(v)
})
}
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotpEnabled() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabled()
})
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (u *UserUpsertOne) SetTotpEnabledAt(v time.Time) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabledAt(v)
})
}
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotpEnabledAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabledAt()
})
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearTotpEnabledAt()
})
}
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1393,6 +1558,62 @@ func (u *UserUpsertBulk) UpdateNotes() *UserUpsertBulk {
})
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (u *UserUpsertBulk) SetTotpSecretEncrypted(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotpSecretEncrypted(v)
})
}
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotpSecretEncrypted() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpSecretEncrypted()
})
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (u *UserUpsertBulk) ClearTotpSecretEncrypted() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearTotpSecretEncrypted()
})
}
// SetTotpEnabled sets the "totp_enabled" field.
func (u *UserUpsertBulk) SetTotpEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabled(v)
})
}
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotpEnabled() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabled()
})
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (u *UserUpsertBulk) SetTotpEnabledAt(v time.Time) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabledAt(v)
})
}
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotpEnabledAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabledAt()
})
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearTotpEnabledAt()
})
}
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -187,6 +187,60 @@ func (_u *UserUpdate) SetNillableNotes(v *string) *UserUpdate {
return _u
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (_u *UserUpdate) SetTotpSecretEncrypted(v string) *UserUpdate {
_u.mutation.SetTotpSecretEncrypted(v)
return _u
}
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotpSecretEncrypted(v *string) *UserUpdate {
if v != nil {
_u.SetTotpSecretEncrypted(*v)
}
return _u
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (_u *UserUpdate) ClearTotpSecretEncrypted() *UserUpdate {
_u.mutation.ClearTotpSecretEncrypted()
return _u
}
// SetTotpEnabled sets the "totp_enabled" field.
func (_u *UserUpdate) SetTotpEnabled(v bool) *UserUpdate {
_u.mutation.SetTotpEnabled(v)
return _u
}
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotpEnabled(v *bool) *UserUpdate {
if v != nil {
_u.SetTotpEnabled(*v)
}
return _u
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (_u *UserUpdate) SetTotpEnabledAt(v time.Time) *UserUpdate {
_u.mutation.SetTotpEnabledAt(v)
return _u
}
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotpEnabledAt(v *time.Time) *UserUpdate {
if v != nil {
_u.SetTotpEnabledAt(*v)
}
return _u
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
_u.mutation.ClearTotpEnabledAt()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -603,6 +657,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
}
if _u.mutation.TotpSecretEncryptedCleared() {
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
}
if value, ok := _u.mutation.TotpEnabled(); ok {
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.TotpEnabledAt(); ok {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
}
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1147,6 +1216,60 @@ func (_u *UserUpdateOne) SetNillableNotes(v *string) *UserUpdateOne {
return _u
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (_u *UserUpdateOne) SetTotpSecretEncrypted(v string) *UserUpdateOne {
_u.mutation.SetTotpSecretEncrypted(v)
return _u
}
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotpSecretEncrypted(v *string) *UserUpdateOne {
if v != nil {
_u.SetTotpSecretEncrypted(*v)
}
return _u
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (_u *UserUpdateOne) ClearTotpSecretEncrypted() *UserUpdateOne {
_u.mutation.ClearTotpSecretEncrypted()
return _u
}
// SetTotpEnabled sets the "totp_enabled" field.
func (_u *UserUpdateOne) SetTotpEnabled(v bool) *UserUpdateOne {
_u.mutation.SetTotpEnabled(v)
return _u
}
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotpEnabled(v *bool) *UserUpdateOne {
if v != nil {
_u.SetTotpEnabled(*v)
}
return _u
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (_u *UserUpdateOne) SetTotpEnabledAt(v time.Time) *UserUpdateOne {
_u.mutation.SetTotpEnabledAt(v)
return _u
}
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotpEnabledAt(v *time.Time) *UserUpdateOne {
if v != nil {
_u.SetTotpEnabledAt(*v)
}
return _u
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
_u.mutation.ClearTotpEnabledAt()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1593,6 +1716,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
}
if _u.mutation.TotpSecretEncryptedCleared() {
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
}
if value, ok := _u.mutation.TotpEnabled(); ok {
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.TotpEnabledAt(); ok {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
}
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,

View File

@@ -31,11 +31,13 @@ require (
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect
dario.cat/mergo v1.0.2 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -97,6 +99,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
@@ -104,9 +107,11 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/pquerna/otp v1.5.0 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.57.1 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
@@ -139,7 +144,7 @@ require (
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.30.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect
@@ -148,4 +153,8 @@ require (
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
modernc.org/sqlite v1.44.1 // indirect
)

View File

@@ -20,6 +20,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -141,6 +143,7 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
@@ -199,6 +202,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
@@ -214,6 +219,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
@@ -224,6 +231,8 @@ github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4Vi
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
@@ -338,6 +347,8 @@ golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
@@ -365,6 +376,7 @@ golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY=
golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -387,4 +399,12 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas=
modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -19,7 +19,9 @@ const (
RunModeSimple = "simple"
)
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' https://challenges.cloudflare.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
@@ -45,6 +47,7 @@ type Config struct {
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
@@ -53,6 +56,7 @@ type Config struct {
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
@@ -232,6 +236,10 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// SessionIdleTimeoutMinutes: 会话空闲超时时间(分钟),默认 5 分钟
// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制功能
// 空闲超过此时间的会话将被自动释放
SessionIdleTimeoutMinutes int `mapstructure:"session_idle_timeout_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时0表示禁用
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
@@ -251,8 +259,43 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"`
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
MaxAccountSwitches int `mapstructure:"max_account_switches"`
// Gemini 账户切换最大次数Gemini 平台单独配置,因 API 限制更严格)
MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"`
// Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用
AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"`
// Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
// TLSFingerprint: TLS指纹伪装配置
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
}
// TLSFingerprintConfig TLS指纹伪装配置
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
type TLSFingerprintConfig struct {
// Enabled: 是否全局启用TLS指纹功能
Enabled bool `mapstructure:"enabled"`
// Profiles: 预定义的TLS指纹配置模板
// key 为模板名称,如 "claude_cli_v2", "chrome_120" 等
Profiles map[string]TLSProfileConfig `mapstructure:"profiles"`
}
// TLSProfileConfig 单个TLS指纹模板的配置
type TLSProfileConfig struct {
// Name: 模板显示名称
Name string `mapstructure:"name"`
// EnableGREASE: 是否启用GREASE扩展Chrome使用Node.js不使用
EnableGREASE bool `mapstructure:"enable_grease"`
// CipherSuites: TLS加密套件列表空则使用内置默认值
CipherSuites []uint16 `mapstructure:"cipher_suites"`
// Curves: 椭圆曲线列表(空则使用内置默认值)
Curves []uint16 `mapstructure:"curves"`
// PointFormats: 点格式列表(空则使用内置默认值)
PointFormats []uint8 `mapstructure:"point_formats"`
}
// GatewaySchedulingConfig accounts scheduling configuration.
@@ -265,6 +308,9 @@ type GatewaySchedulingConfig struct {
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
// 兜底层账户选择策略: "last_used"(按最后使用时间排序,默认) 或 "random"(随机)
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
// 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
@@ -421,6 +467,16 @@ type JWTConfig struct {
ExpireHour int `mapstructure:"expire_hour"`
}
// TotpConfig TOTP 双因素认证配置
type TotpConfig struct {
// EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥32 字节 hex 编码)
// 如果为空,将自动生成一个随机密钥(仅适用于开发环境)
EncryptionKey string `mapstructure:"encryption_key"`
// EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成)
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
EncryptionKeyConfigured bool `mapstructure:"-"`
}
type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
@@ -487,6 +543,20 @@ type DashboardAggregationRetentionConfig struct {
DailyDays int `mapstructure:"daily_days"`
}
// UsageCleanupConfig 使用记录清理任务配置
type UsageCleanupConfig struct {
// Enabled: 是否启用清理任务执行器
Enabled bool `mapstructure:"enabled"`
// MaxRangeDays: 单次任务允许的最大时间跨度(天)
MaxRangeDays int `mapstructure:"max_range_days"`
// BatchSize: 单批删除数量
BatchSize int `mapstructure:"batch_size"`
// WorkerIntervalSeconds: 后台任务轮询间隔(秒)
WorkerIntervalSeconds int `mapstructure:"worker_interval_seconds"`
// TaskTimeoutSeconds: 单次任务最大执行时长(秒)
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
}
func NormalizeRunMode(value string) string {
normalized := strings.ToLower(strings.TrimSpace(value))
switch normalized {
@@ -567,6 +637,20 @@ func Load() (*Config, error) {
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
if cfg.Totp.EncryptionKey == "" {
key, err := generateJWTSecret(32) // Reuse the same random generation function
if err != nil {
return nil, fmt.Errorf("generate totp encryption key error: %w", err)
}
cfg.Totp.EncryptionKey = key
cfg.Totp.EncryptionKeyConfigured = false
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
} else {
cfg.Totp.EncryptionKeyConfigured = true
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
@@ -697,6 +781,9 @@ func setDefaults() {
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
// TOTP
viper.SetDefault("totp.encryption_key", "")
// Default
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.
@@ -747,12 +834,22 @@ func setDefaults() {
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
// Usage cleanup task
viper.SetDefault("usage_cleanup.enabled", true)
viper.SetDefault("usage_cleanup.max_range_days", 31)
viper.SetDefault("usage_cleanup.batch_size", 5000)
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
// Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", true)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_account_switches", 10)
viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
@@ -765,11 +862,12 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
@@ -781,6 +879,8 @@ func setDefaults() {
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
// TLS指纹伪装配置默认关闭需要账号级别单独启用
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
viper.SetDefault("concurrency.ping_interval", 10)
// TokenRefresh
@@ -983,6 +1083,33 @@ func (c *Config) Validate() error {
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
}
}
if c.UsageCleanup.Enabled {
if c.UsageCleanup.MaxRangeDays <= 0 {
return fmt.Errorf("usage_cleanup.max_range_days must be positive")
}
if c.UsageCleanup.BatchSize <= 0 {
return fmt.Errorf("usage_cleanup.batch_size must be positive")
}
if c.UsageCleanup.WorkerIntervalSeconds <= 0 {
return fmt.Errorf("usage_cleanup.worker_interval_seconds must be positive")
}
if c.UsageCleanup.TaskTimeoutSeconds <= 0 {
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be positive")
}
} else {
if c.UsageCleanup.MaxRangeDays < 0 {
return fmt.Errorf("usage_cleanup.max_range_days must be non-negative")
}
if c.UsageCleanup.BatchSize < 0 {
return fmt.Errorf("usage_cleanup.batch_size must be non-negative")
}
if c.UsageCleanup.WorkerIntervalSeconds < 0 {
return fmt.Errorf("usage_cleanup.worker_interval_seconds must be non-negative")
}
if c.UsageCleanup.TaskTimeoutSeconds < 0 {
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
}
}
if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive")
}

View File

@@ -280,3 +280,573 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
t.Fatalf("Validate() expected backfill_max_days error, got: %v", err)
}
}
func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.UsageCleanup.Enabled {
t.Fatalf("UsageCleanup.Enabled = false, want true")
}
if cfg.UsageCleanup.MaxRangeDays != 31 {
t.Fatalf("UsageCleanup.MaxRangeDays = %d, want 31", cfg.UsageCleanup.MaxRangeDays)
}
if cfg.UsageCleanup.BatchSize != 5000 {
t.Fatalf("UsageCleanup.BatchSize = %d, want 5000", cfg.UsageCleanup.BatchSize)
}
if cfg.UsageCleanup.WorkerIntervalSeconds != 10 {
t.Fatalf("UsageCleanup.WorkerIntervalSeconds = %d, want 10", cfg.UsageCleanup.WorkerIntervalSeconds)
}
if cfg.UsageCleanup.TaskTimeoutSeconds != 1800 {
t.Fatalf("UsageCleanup.TaskTimeoutSeconds = %d, want 1800", cfg.UsageCleanup.TaskTimeoutSeconds)
}
}
func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.UsageCleanup.Enabled = true
cfg.UsageCleanup.MaxRangeDays = 0
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for usage_cleanup.max_range_days, got nil")
}
if !strings.Contains(err.Error(), "usage_cleanup.max_range_days") {
t.Fatalf("Validate() expected max_range_days error, got: %v", err)
}
}
func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.UsageCleanup.Enabled = false
cfg.UsageCleanup.BatchSize = -1
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for usage_cleanup.batch_size, got nil")
}
if !strings.Contains(err.Error(), "usage_cleanup.batch_size") {
t.Fatalf("Validate() expected batch_size error, got: %v", err)
}
}
func TestConfigAddressHelpers(t *testing.T) {
server := ServerConfig{Host: "127.0.0.1", Port: 9000}
if server.Address() != "127.0.0.1:9000" {
t.Fatalf("ServerConfig.Address() = %q", server.Address())
}
dbCfg := DatabaseConfig{
Host: "localhost",
Port: 5432,
User: "postgres",
Password: "",
DBName: "sub2api",
SSLMode: "disable",
}
if !strings.Contains(dbCfg.DSN(), "password=") {
} else {
t.Fatalf("DatabaseConfig.DSN() should not include password when empty")
}
dbCfg.Password = "secret"
if !strings.Contains(dbCfg.DSN(), "password=secret") {
t.Fatalf("DatabaseConfig.DSN() missing password")
}
dbCfg.Password = ""
if strings.Contains(dbCfg.DSNWithTimezone("UTC"), "password=") {
t.Fatalf("DatabaseConfig.DSNWithTimezone() should omit password when empty")
}
if !strings.Contains(dbCfg.DSNWithTimezone(""), "TimeZone=Asia/Shanghai") {
t.Fatalf("DatabaseConfig.DSNWithTimezone() should use default timezone")
}
if !strings.Contains(dbCfg.DSNWithTimezone("UTC"), "TimeZone=UTC") {
t.Fatalf("DatabaseConfig.DSNWithTimezone() should use provided timezone")
}
redis := RedisConfig{Host: "redis", Port: 6379}
if redis.Address() != "redis:6379" {
t.Fatalf("RedisConfig.Address() = %q", redis.Address())
}
}
func TestNormalizeStringSlice(t *testing.T) {
values := normalizeStringSlice([]string{" a ", "", "b", " ", "c"})
if len(values) != 3 || values[0] != "a" || values[1] != "b" || values[2] != "c" {
t.Fatalf("normalizeStringSlice() unexpected result: %#v", values)
}
if normalizeStringSlice(nil) != nil {
t.Fatalf("normalizeStringSlice(nil) expected nil slice")
}
}
func TestGetServerAddressFromEnv(t *testing.T) {
t.Setenv("SERVER_HOST", "127.0.0.1")
t.Setenv("SERVER_PORT", "9090")
address := GetServerAddress()
if address != "127.0.0.1:9090" {
t.Fatalf("GetServerAddress() = %q", address)
}
}
func TestValidateAbsoluteHTTPURL(t *testing.T) {
if err := ValidateAbsoluteHTTPURL("https://example.com/path"); err != nil {
t.Fatalf("ValidateAbsoluteHTTPURL valid url error: %v", err)
}
if err := ValidateAbsoluteHTTPURL(""); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject empty url")
}
if err := ValidateAbsoluteHTTPURL("/relative"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject relative url")
}
if err := ValidateAbsoluteHTTPURL("ftp://example.com"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject ftp scheme")
}
if err := ValidateAbsoluteHTTPURL("https://example.com/#frag"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject fragment")
}
}
func TestValidateFrontendRedirectURL(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
}
if err := ValidateFrontendRedirectURL("https://example.com/auth"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL absolute error: %v", err)
}
if err := ValidateFrontendRedirectURL("example.com/path"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject non-absolute url")
}
if err := ValidateFrontendRedirectURL("//evil.com"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject // prefix")
}
if err := ValidateFrontendRedirectURL("javascript:alert(1)"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject javascript scheme")
}
}
func TestWarnIfInsecureURL(t *testing.T) {
warnIfInsecureURL("test", "http://example.com")
warnIfInsecureURL("test", "bad://url")
}
func TestGenerateJWTSecretDefaultLength(t *testing.T) {
secret, err := generateJWTSecret(0)
if err != nil {
t.Fatalf("generateJWTSecret error: %v", err)
}
if len(secret) == 0 {
t.Fatalf("generateJWTSecret returned empty string")
}
}
func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Ops.Cleanup.Enabled = true
cfg.Ops.Cleanup.Schedule = ""
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for ops.cleanup.schedule")
}
if !strings.Contains(err.Error(), "ops.cleanup.schedule") {
t.Fatalf("Validate() expected ops.cleanup.schedule error, got: %v", err)
}
}
func TestValidateConcurrencyPingInterval(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Concurrency.PingInterval = 3
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for concurrency.ping_interval")
}
if !strings.Contains(err.Error(), "concurrency.ping_interval") {
t.Fatalf("Validate() expected concurrency.ping_interval error, got: %v", err)
}
}
func TestProvideConfig(t *testing.T) {
viper.Reset()
if _, err := ProvideConfig(); err != nil {
t.Fatalf("ProvideConfig() error: %v", err)
}
}
func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Security.CSP.Enabled = true
cfg.Security.CSP.Policy = "default-src 'self'"
cfg.LinuxDo.Enabled = true
cfg.LinuxDo.ClientID = "client"
cfg.LinuxDo.ClientSecret = "secret"
cfg.LinuxDo.AuthorizeURL = "https://example.com/oauth2/authorize"
cfg.LinuxDo.TokenURL = "https://example.com/oauth2/token"
cfg.LinuxDo.UserInfoURL = "https://example.com/oauth2/userinfo"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() unexpected error: %v", err)
}
}
func TestValidateJWTSecretStrength(t *testing.T) {
if !isWeakJWTSecret("change-me-in-production") {
t.Fatalf("isWeakJWTSecret should detect weak secret")
}
if isWeakJWTSecret("StrongSecretValue") {
t.Fatalf("isWeakJWTSecret should accept strong secret")
}
}
func TestGenerateJWTSecretWithLength(t *testing.T) {
secret, err := generateJWTSecret(16)
if err != nil {
t.Fatalf("generateJWTSecret error: %v", err)
}
if len(secret) == 0 {
t.Fatalf("generateJWTSecret returned empty string")
}
}
func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
}
}
func TestValidateFrontendRedirectURLInvalidChars(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/\ncallback"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject invalid chars")
}
if err := ValidateFrontendRedirectURL("http://"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject missing host")
}
if err := ValidateFrontendRedirectURL("mailto:user@example.com"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject mailto")
}
}
func TestWarnIfInsecureURLHTTPS(t *testing.T) {
warnIfInsecureURL("secure", "https://example.com")
}
func TestValidateConfigErrors(t *testing.T) {
buildValid := func(t *testing.T) *Config {
t.Helper()
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
return cfg
}
cases := []struct {
name string
mutate func(*Config)
wantErr string
}{
{
name: "jwt expire hour positive",
mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
wantErr: "jwt.expire_hour must be positive",
},
{
name: "jwt expire hour max",
mutate: func(c *Config) { c.JWT.ExpireHour = 200 },
wantErr: "jwt.expire_hour must be <= 168",
},
{
name: "csp policy required",
mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" },
wantErr: "security.csp.policy",
},
{
name: "linuxdo client id required",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
c.LinuxDo.ClientID = ""
},
wantErr: "linuxdo_connect.client_id",
},
{
name: "linuxdo token auth method",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
c.LinuxDo.ClientID = "client"
c.LinuxDo.ClientSecret = "secret"
c.LinuxDo.AuthorizeURL = "https://example.com/authorize"
c.LinuxDo.TokenURL = "https://example.com/token"
c.LinuxDo.UserInfoURL = "https://example.com/userinfo"
c.LinuxDo.RedirectURL = "https://example.com/callback"
c.LinuxDo.FrontendRedirectURL = "/auth/callback"
c.LinuxDo.TokenAuthMethod = "invalid"
},
wantErr: "linuxdo_connect.token_auth_method",
},
{
name: "billing circuit breaker threshold",
mutate: func(c *Config) { c.Billing.CircuitBreaker.FailureThreshold = 0 },
wantErr: "billing.circuit_breaker.failure_threshold",
},
{
name: "billing circuit breaker reset",
mutate: func(c *Config) { c.Billing.CircuitBreaker.ResetTimeoutSeconds = 0 },
wantErr: "billing.circuit_breaker.reset_timeout_seconds",
},
{
name: "billing circuit breaker half open",
mutate: func(c *Config) { c.Billing.CircuitBreaker.HalfOpenRequests = 0 },
wantErr: "billing.circuit_breaker.half_open_requests",
},
{
name: "database max open conns",
mutate: func(c *Config) { c.Database.MaxOpenConns = 0 },
wantErr: "database.max_open_conns",
},
{
name: "database max lifetime",
mutate: func(c *Config) { c.Database.ConnMaxLifetimeMinutes = -1 },
wantErr: "database.conn_max_lifetime_minutes",
},
{
name: "database idle exceeds open",
mutate: func(c *Config) { c.Database.MaxIdleConns = c.Database.MaxOpenConns + 1 },
wantErr: "database.max_idle_conns cannot exceed",
},
{
name: "redis dial timeout",
mutate: func(c *Config) { c.Redis.DialTimeoutSeconds = 0 },
wantErr: "redis.dial_timeout_seconds",
},
{
name: "redis read timeout",
mutate: func(c *Config) { c.Redis.ReadTimeoutSeconds = 0 },
wantErr: "redis.read_timeout_seconds",
},
{
name: "redis write timeout",
mutate: func(c *Config) { c.Redis.WriteTimeoutSeconds = 0 },
wantErr: "redis.write_timeout_seconds",
},
{
name: "redis pool size",
mutate: func(c *Config) { c.Redis.PoolSize = 0 },
wantErr: "redis.pool_size",
},
{
name: "redis idle exceeds pool",
mutate: func(c *Config) { c.Redis.MinIdleConns = c.Redis.PoolSize + 1 },
wantErr: "redis.min_idle_conns cannot exceed",
},
{
name: "dashboard cache disabled negative",
mutate: func(c *Config) { c.Dashboard.Enabled = false; c.Dashboard.StatsTTLSeconds = -1 },
wantErr: "dashboard_cache.stats_ttl_seconds",
},
{
name: "dashboard cache fresh ttl positive",
mutate: func(c *Config) { c.Dashboard.Enabled = true; c.Dashboard.StatsFreshTTLSeconds = 0 },
wantErr: "dashboard_cache.stats_fresh_ttl_seconds",
},
{
name: "dashboard aggregation enabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.IntervalSeconds = 0 },
wantErr: "dashboard_aggregation.interval_seconds",
},
{
name: "dashboard aggregation backfill positive",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.BackfillEnabled = true
c.DashboardAgg.BackfillMaxDays = 0
},
wantErr: "dashboard_aggregation.backfill_max_days",
},
{
name: "dashboard aggregation retention",
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
wantErr: "dashboard_aggregation.retention.usage_logs_days",
},
{
name: "dashboard aggregation disabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
wantErr: "dashboard_aggregation.interval_seconds",
},
{
name: "usage cleanup max range",
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.MaxRangeDays = 0 },
wantErr: "usage_cleanup.max_range_days",
},
{
name: "usage cleanup worker interval",
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.WorkerIntervalSeconds = 0 },
wantErr: "usage_cleanup.worker_interval_seconds",
},
{
name: "usage cleanup batch size",
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.BatchSize = 0 },
wantErr: "usage_cleanup.batch_size",
},
{
name: "usage cleanup disabled negative",
mutate: func(c *Config) { c.UsageCleanup.Enabled = false; c.UsageCleanup.BatchSize = -1 },
wantErr: "usage_cleanup.batch_size",
},
{
name: "gateway max body size",
mutate: func(c *Config) { c.Gateway.MaxBodySize = 0 },
wantErr: "gateway.max_body_size",
},
{
name: "gateway max idle conns",
mutate: func(c *Config) { c.Gateway.MaxIdleConns = 0 },
wantErr: "gateway.max_idle_conns",
},
{
name: "gateway max idle conns per host",
mutate: func(c *Config) { c.Gateway.MaxIdleConnsPerHost = 0 },
wantErr: "gateway.max_idle_conns_per_host",
},
{
name: "gateway idle timeout",
mutate: func(c *Config) { c.Gateway.IdleConnTimeoutSeconds = 0 },
wantErr: "gateway.idle_conn_timeout_seconds",
},
{
name: "gateway max upstream clients",
mutate: func(c *Config) { c.Gateway.MaxUpstreamClients = 0 },
wantErr: "gateway.max_upstream_clients",
},
{
name: "gateway client idle ttl",
mutate: func(c *Config) { c.Gateway.ClientIdleTTLSeconds = 0 },
wantErr: "gateway.client_idle_ttl_seconds",
},
{
name: "gateway concurrency slot ttl",
mutate: func(c *Config) { c.Gateway.ConcurrencySlotTTLMinutes = 0 },
wantErr: "gateway.concurrency_slot_ttl_minutes",
},
{
name: "gateway max conns per host",
mutate: func(c *Config) { c.Gateway.MaxConnsPerHost = -1 },
wantErr: "gateway.max_conns_per_host",
},
{
name: "gateway connection isolation",
mutate: func(c *Config) { c.Gateway.ConnectionPoolIsolation = "invalid" },
wantErr: "gateway.connection_pool_isolation",
},
{
name: "gateway stream keepalive range",
mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
wantErr: "gateway.stream_keepalive_interval",
},
{
name: "gateway stream data interval range",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
wantErr: "gateway.stream_data_interval_timeout",
},
{
name: "gateway stream data interval negative",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 },
wantErr: "gateway.stream_data_interval_timeout must be non-negative",
},
{
name: "gateway max line size",
mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 },
wantErr: "gateway.max_line_size must be at least",
},
{
name: "gateway max line size negative",
mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
wantErr: "gateway.max_line_size must be non-negative",
},
{
name: "gateway scheduling sticky waiting",
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
wantErr: "gateway.scheduling.sticky_session_max_waiting",
},
{
name: "gateway scheduling outbox poll",
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },
wantErr: "gateway.scheduling.outbox_poll_interval_seconds",
},
{
name: "gateway scheduling outbox failures",
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxLagRebuildFailures = 0 },
wantErr: "gateway.scheduling.outbox_lag_rebuild_failures",
},
{
name: "gateway outbox lag rebuild",
mutate: func(c *Config) {
c.Gateway.Scheduling.OutboxLagWarnSeconds = 10
c.Gateway.Scheduling.OutboxLagRebuildSeconds = 5
},
wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
},
{
name: "ops metrics collector ttl",
mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
wantErr: "ops.metrics_collector_cache.ttl",
},
{
name: "ops cleanup retention",
mutate: func(c *Config) { c.Ops.Cleanup.ErrorLogRetentionDays = -1 },
wantErr: "ops.cleanup.error_log_retention_days",
},
{
name: "ops cleanup minute retention",
mutate: func(c *Config) { c.Ops.Cleanup.MinuteMetricsRetentionDays = -1 },
wantErr: "ops.cleanup.minute_metrics_retention_days",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
cfg := buildValid(t)
tt.mutate(cfg)
err := cfg.Validate()
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr)
}
})
}
}

View File

@@ -44,6 +44,8 @@ type AccountHandler struct {
accountTestService *service.AccountTestService
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
sessionLimitCache service.SessionLimitCache
tokenCacheInvalidator service.TokenCacheInvalidator
}
// NewAccountHandler creates a new admin account handler
@@ -58,6 +60,8 @@ func NewAccountHandler(
accountTestService *service.AccountTestService,
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
sessionLimitCache service.SessionLimitCache,
tokenCacheInvalidator service.TokenCacheInvalidator,
) *AccountHandler {
return &AccountHandler{
adminService: adminService,
@@ -70,6 +74,8 @@ func NewAccountHandler(
accountTestService: accountTestService,
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
sessionLimitCache: sessionLimitCache,
tokenCacheInvalidator: tokenCacheInvalidator,
}
}
@@ -84,6 +90,7 @@ type CreateAccountRequest struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
GroupIDs []int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
@@ -101,6 +108,7 @@ type UpdateAccountRequest struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
@@ -115,6 +123,7 @@ type BulkUpdateAccountsRequest struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
GroupIDs *[]int64 `json:"group_ids"`
@@ -127,6 +136,9 @@ type BulkUpdateAccountsRequest struct {
type AccountWithConcurrency struct {
*dto.Account
CurrentConcurrency int `json:"current_concurrency"`
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
}
// List handles listing all accounts with pagination
@@ -161,13 +173,87 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts = make(map[int64]int)
}
// 识别需要查询窗口费用和会话数的账号Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
if acc.IsAnthropicOAuthOrSetupToken() {
if acc.GetWindowCostLimit() > 0 {
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
}
if acc.GetMaxSessions() > 0 {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
}
}
// 并行获取窗口费用和活跃会话数
var windowCosts map[int64]float64
var activeSessions map[int64]int
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
if activeSessions == nil {
activeSessions = make(map[int64]int)
}
}
// 获取窗口费用(并行查询)
if len(windowCostAccountIDs) > 0 {
windowCosts = make(map[int64]float64)
var mu sync.Mutex
g, gctx := errgroup.WithContext(c.Request.Context())
g.SetLimit(10) // 限制并发数
for i := range accounts {
acc := &accounts[i]
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
continue
}
accCopy := acc // 闭包捕获
g.Go(func() error {
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime := accCopy.GetCurrentWindowStartTime()
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
if err == nil && stats != nil {
mu.Lock()
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
mu.Unlock()
}
return nil // 不返回错误,允许部分失败
})
}
_ = g.Wait()
}
// Build response with concurrency info
result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts {
result[i] = AccountWithConcurrency{
Account: dto.AccountFromService(&accounts[i]),
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
acc := &accounts[i]
item := AccountWithConcurrency{
Account: dto.AccountFromService(acc),
CurrentConcurrency: concurrencyCounts[acc.ID],
}
// 添加窗口费用(仅当启用时)
if windowCosts != nil {
if cost, ok := windowCosts[acc.ID]; ok {
item.CurrentWindowCost = &cost
}
}
// 添加活跃会话数(仅当启用时)
if activeSessions != nil {
if count, ok := activeSessions[acc.ID]; ok {
item.ActiveSessions = &count
}
}
result[i] = item
}
response.Paginated(c, result, total, page, pageSize)
@@ -199,6 +285,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -213,6 +303,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
@@ -258,6 +349,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -271,6 +366,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
ProxyID: req.ProxyID,
Concurrency: req.Concurrency, // 指针类型nil 表示未提供
Priority: req.Priority, // 指针类型nil 表示未提供
RateMultiplier: req.RateMultiplier,
Status: req.Status,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
@@ -450,6 +546,41 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
newCredentials[k] = v
}
}
// 特殊处理 project_id如果新值为空但旧值非空保留旧值
// 这确保了即使 LoadCodeAssist 失败project_id 也不会丢失
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
newCredentials["project_id"] = oldProjectID
}
}
// 如果 project_id 获取失败,更新凭证但不标记为 error
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
if tokenInfo.ProjectIDMissing {
// 先更新凭证token 本身刷新成功了)
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if updateErr != nil {
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
return
}
// 不标记为 error只返回警告信息
response.Success(c, gin.H{
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
"warning": "missing_project_id_temporary",
})
return
}
// 成功获取到 project_id如果之前是 missing_project_id 错误则清除
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
return
}
}
} else {
// Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
@@ -485,6 +616,14 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
return
}
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
if h.tokenCacheInvalidator != nil {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
// 缓存失效失败只记录日志,不影响主流程
_ = c.Error(invalidateErr)
}
}
response.Success(c, dto.AccountFromService(updatedAccount))
}
@@ -534,6 +673,15 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
return
}
// 清除错误后,同时清除 token 缓存,确保下次请求会获取最新的 token触发刷新或从 DB 读取)
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
// 缓存失效失败只记录日志,不影响主流程
_ = c.Error(invalidateErr)
}
}
response.Success(c, dto.AccountFromService(account))
}
@@ -652,6 +800,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -660,6 +812,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
req.ProxyID != nil ||
req.Concurrency != nil ||
req.Priority != nil ||
req.RateMultiplier != nil ||
req.Status != "" ||
req.Schedulable != nil ||
req.GroupIDs != nil ||
@@ -677,6 +830,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
Status: req.Status,
Schedulable: req.Schedulable,
GroupIDs: req.GroupIDs,

View File

@@ -0,0 +1,262 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAdminRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc)
groupHandler := NewGroupHandler(adminSvc)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc)
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
router.POST("/api/v1/admin/users", userHandler.Create)
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
router.POST("/api/v1/admin/users/:id/balance", userHandler.UpdateBalance)
router.GET("/api/v1/admin/users/:id/api-keys", userHandler.GetUserAPIKeys)
router.GET("/api/v1/admin/users/:id/usage", userHandler.GetUserUsage)
router.GET("/api/v1/admin/groups", groupHandler.List)
router.GET("/api/v1/admin/groups/all", groupHandler.GetAll)
router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID)
router.POST("/api/v1/admin/groups", groupHandler.Create)
router.PUT("/api/v1/admin/groups/:id", groupHandler.Update)
router.DELETE("/api/v1/admin/groups/:id", groupHandler.Delete)
router.GET("/api/v1/admin/groups/:id/stats", groupHandler.GetStats)
router.GET("/api/v1/admin/groups/:id/api-keys", groupHandler.GetGroupAPIKeys)
router.GET("/api/v1/admin/proxies", proxyHandler.List)
router.GET("/api/v1/admin/proxies/all", proxyHandler.GetAll)
router.GET("/api/v1/admin/proxies/:id", proxyHandler.GetByID)
router.POST("/api/v1/admin/proxies", proxyHandler.Create)
router.PUT("/api/v1/admin/proxies/:id", proxyHandler.Update)
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
router.GET("/api/v1/admin/redeem-codes", redeemHandler.List)
router.GET("/api/v1/admin/redeem-codes/:id", redeemHandler.GetByID)
router.POST("/api/v1/admin/redeem-codes", redeemHandler.Generate)
router.DELETE("/api/v1/admin/redeem-codes/:id", redeemHandler.Delete)
router.POST("/api/v1/admin/redeem-codes/batch-delete", redeemHandler.BatchDelete)
router.POST("/api/v1/admin/redeem-codes/:id/expire", redeemHandler.Expire)
router.GET("/api/v1/admin/redeem-codes/:id/stats", redeemHandler.GetStats)
return router, adminSvc
}
func TestUserHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/users?page=1&page_size=20", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
body, _ := json.Marshal(createBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
updateBody := map[string]any{"email": "updated@example.com"}
body, _ = json.Marshal(updateBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/users/1", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/balance", bytes.NewBufferString(`{"balance":1,"operation":"add"}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/api-keys", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/usage?period=today", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestGroupHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/all", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ = json.Marshal(map[string]any{"name": "update"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/groups/2", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/groups/2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/api-keys", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestProxyHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/all", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"name": "proxy", "protocol": "http", "host": "localhost", "port": 8080})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ = json.Marshal(map[string]any{"name": "proxy2"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/proxies/4", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/proxies/4", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/test", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/accounts", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestRedeemHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"count": 1, "type": "balance", "value": 10})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/redeem-codes/5", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/5/expire", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}

View File

@@ -0,0 +1,134 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestParseTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-01&end_date=2024-01-02&timezone=UTC", nil)
c.Request = req
start, end := parseTimeRange(c)
require.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), start)
require.Equal(t, time.Date(2024, 1, 3, 0, 0, 0, 0, time.UTC), end)
req = httptest.NewRequest(http.MethodGet, "/?start_date=bad&timezone=UTC", nil)
c.Request = req
start, end = parseTimeRange(c)
require.False(t, start.IsZero())
require.False(t, end.IsZero())
}
func TestParseOpsViewParam(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?view=excluded", nil)
require.Equal(t, opsListViewExcluded, parseOpsViewParam(c))
c2, _ := gin.CreateTestContext(w)
c2.Request = httptest.NewRequest(http.MethodGet, "/?view=all", nil)
require.Equal(t, opsListViewAll, parseOpsViewParam(c2))
c3, _ := gin.CreateTestContext(w)
c3.Request = httptest.NewRequest(http.MethodGet, "/?view=unknown", nil)
require.Equal(t, opsListViewErrors, parseOpsViewParam(c3))
require.Equal(t, "", parseOpsViewParam(nil))
}
func TestParseOpsDuration(t *testing.T) {
dur, ok := parseOpsDuration("1h")
require.True(t, ok)
require.Equal(t, time.Hour, dur)
_, ok = parseOpsDuration("invalid")
require.False(t, ok)
}
func TestParseOpsTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
now := time.Now().UTC()
startStr := now.Add(-time.Hour).Format(time.RFC3339)
endStr := now.Format(time.RFC3339)
c.Request = httptest.NewRequest(http.MethodGet, "/?start_time="+startStr+"&end_time="+endStr, nil)
start, end, err := parseOpsTimeRange(c, "1h")
require.NoError(t, err)
require.True(t, start.Before(end))
c2, _ := gin.CreateTestContext(w)
c2.Request = httptest.NewRequest(http.MethodGet, "/?start_time=bad", nil)
_, _, err = parseOpsTimeRange(c2, "1h")
require.Error(t, err)
}
func TestParseOpsRealtimeWindow(t *testing.T) {
dur, label, ok := parseOpsRealtimeWindow("5m")
require.True(t, ok)
require.Equal(t, 5*time.Minute, dur)
require.Equal(t, "5min", label)
_, _, ok = parseOpsRealtimeWindow("invalid")
require.False(t, ok)
}
func TestPickThroughputBucketSeconds(t *testing.T) {
require.Equal(t, 60, pickThroughputBucketSeconds(30*time.Minute))
require.Equal(t, 300, pickThroughputBucketSeconds(6*time.Hour))
require.Equal(t, 3600, pickThroughputBucketSeconds(48*time.Hour))
}
func TestParseOpsQueryMode(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?mode=raw", nil)
require.Equal(t, service.ParseOpsQueryMode("raw"), parseOpsQueryMode(c))
require.Equal(t, service.OpsQueryMode(""), parseOpsQueryMode(nil))
}
func TestOpsAlertRuleValidation(t *testing.T) {
raw := map[string]json.RawMessage{
"name": json.RawMessage(`"High error rate"`),
"metric_type": json.RawMessage(`"error_rate"`),
"operator": json.RawMessage(`">"`),
"threshold": json.RawMessage(`90`),
}
validated, err := validateOpsAlertRulePayload(raw)
require.NoError(t, err)
require.Equal(t, "High error rate", validated.Name)
_, err = validateOpsAlertRulePayload(map[string]json.RawMessage{})
require.Error(t, err)
require.True(t, isPercentOrRateMetric("error_rate"))
require.False(t, isPercentOrRateMetric("concurrency_queue_depth"))
}
func TestOpsWSHelpers(t *testing.T) {
prefixes, invalid := parseTrustedProxyList("10.0.0.0/8,invalid")
require.Len(t, prefixes, 1)
require.Len(t, invalid, 1)
host := hostWithoutPort("example.com:443")
require.Equal(t, "example.com", host)
addr := netip.MustParseAddr("10.0.0.1")
require.True(t, isAddrInTrustedProxies(addr, prefixes))
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
}

View File

@@ -0,0 +1,294 @@
package admin
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type stubAdminService struct {
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
}
func newStubAdminService() *stubAdminService {
now := time.Now().UTC()
user := service.User{
ID: 1,
Email: "user@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
apiKey := service.APIKey{
ID: 10,
UserID: user.ID,
Key: "sk-test",
Name: "test",
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
group := service.Group{
ID: 2,
Name: "group",
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
account := service.Account{
ID: 3,
Name: "account",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
proxy := service.Proxy{
ID: 4,
Name: "proxy",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
redeem := service.RedeemCode{
ID: 5,
Code: "R-TEST",
Type: service.RedeemTypeBalance,
Value: 10,
Status: service.StatusUnused,
CreatedAt: now,
}
return &stubAdminService{
users: []service.User{user},
apiKeys: []service.APIKey{apiKey},
groups: []service.Group{group},
accounts: []service.Account{account},
proxies: []service.Proxy{proxy},
proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}},
redeems: []service.RedeemCode{redeem},
}
}
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) {
return s.users, int64(len(s.users)), nil
}
func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
for i := range s.users {
if s.users[i].ID == id {
return &s.users[i], nil
}
}
user := service.User{ID: id, Email: "user@example.com", Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) CreateUser(ctx context.Context, input *service.CreateUserInput) (*service.User, error) {
user := service.User{ID: 100, Email: input.Email, Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) UpdateUser(ctx context.Context, id int64, input *service.UpdateUserInput) (*service.User, error) {
user := service.User{ID: id, Email: "updated@example.com", Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) DeleteUser(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*service.User, error) {
user := service.User{ID: userID, Balance: balance, Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
return map[string]any{"user_id": userID}, nil
}
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
func (s *stubAdminService) GetAllGroups(ctx context.Context) ([]service.Group, error) {
return s.groups, nil
}
func (s *stubAdminService) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return s.groups, nil
}
func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Group, error) {
group := service.Group{ID: id, Name: "group", Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) {
group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) UpdateGroup(ctx context.Context, id int64, input *service.UpdateGroupInput) (*service.Group, error) {
group := service.Group{ID: id, Name: input.Name, Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) DeleteGroup(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil
}
func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
out := make([]*service.Account, 0, len(ids))
for _, id := range ids {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
out = append(out, &account)
}
return out, nil
}
func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) {
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) RefreshAccountCredentials(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable}
return &account, nil
}
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
return s.proxies, int64(len(s.proxies)), nil
}
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
return s.proxyCounts, int64(len(s.proxyCounts)), nil
}
func (s *stubAdminService) GetAllProxies(ctx context.Context) ([]service.Proxy, error) {
return s.proxies, nil
}
func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
return s.proxyCounts, nil
}
func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) {
proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) {
proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) {
proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) DeleteProxy(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) BatchDeleteProxies(ctx context.Context, ids []int64) (*service.ProxyBatchDeleteResult, error) {
return &service.ProxyBatchDeleteResult{DeletedIDs: ids}, nil
}
func (s *stubAdminService) GetProxyAccounts(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
return []service.ProxyAccountSummary{{ID: 1, Name: "account"}}, nil
}
func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
return false, nil
}
func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) {
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
}
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
return s.redeems, int64(len(s.redeems)), nil
}
func (s *stubAdminService) GetRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUnused}
return &code, nil
}
func (s *stubAdminService) GenerateRedeemCodes(ctx context.Context, input *service.GenerateRedeemCodesInput) ([]service.RedeemCode, error) {
return s.redeems, nil
}
func (s *stubAdminService) DeleteRedeemCode(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
return int64(len(ids)), nil
}
func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUsed}
return &code, nil
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -186,13 +186,17 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
// Parse optional filter params
var userID, apiKeyID int64
var userID, apiKeyID, accountID, groupID int64
var model string
var stream *bool
var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
@@ -203,8 +207,35 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
apiKeyID = id
}
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
accountID = id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = id
}
}
if modelStr := c.Query("model"); modelStr != "" {
model = modelStr
}
if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
@@ -220,12 +251,15 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
var userID, apiKeyID int64
var userID, apiKeyID, accountID, groupID int64
var stream *bool
var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
@@ -236,8 +270,32 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
apiKeyID = id
}
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
accountID = id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = id
}
}
if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return

View File

@@ -40,6 +40,9 @@ type CreateGroupRequest struct {
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
}
// UpdateGroupRequest represents update group request
@@ -60,6 +63,9 @@ type UpdateGroupRequest struct {
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
}
// List handles listing all groups with pagination
@@ -88,9 +94,9 @@ func (h *GroupHandler) List(c *gin.Context) {
return
}
outGroups := make([]dto.Group, 0, len(groups))
outGroups := make([]dto.AdminGroup, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
}
response.Paginated(c, outGroups, total, page, pageSize)
}
@@ -114,9 +120,9 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
return
}
outGroups := make([]dto.Group, 0, len(groups))
outGroups := make([]dto.AdminGroup, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
}
response.Success(c, outGroups)
}
@@ -136,7 +142,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.GroupFromService(group))
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Create handles creating a new group
@@ -149,27 +155,29 @@ func (h *GroupHandler) Create(c *gin.Context) {
}
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.GroupFromService(group))
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Update handles updating a group
@@ -188,28 +196,30 @@ func (h *GroupHandler) Update(c *gin.Context) {
}
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: req.Status,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: req.Status,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.GroupFromService(group))
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Delete handles deleting a group

View File

@@ -7,8 +7,10 @@ import (
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
@@ -18,8 +20,6 @@ var validOpsAlertMetricTypes = []string{
"success_rate",
"error_rate",
"upstream_error_rate",
"p95_latency_ms",
"p99_latency_ms",
"cpu_usage_percent",
"memory_usage_percent",
"concurrency_queue_depth",
@@ -372,8 +372,135 @@ func (h *OpsHandler) DeleteAlertRule(c *gin.Context) {
response.Success(c, gin.H{"deleted": true})
}
// GetAlertEvent returns a single ops alert event.
// GET /api/v1/admin/ops/alert-events/:id
func (h *OpsHandler) GetAlertEvent(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid event ID")
return
}
ev, err := h.opsService.GetAlertEventByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ev)
}
// UpdateAlertEventStatus updates an ops alert event status.
// PUT /api/v1/admin/ops/alert-events/:id/status
func (h *OpsHandler) UpdateAlertEventStatus(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid event ID")
return
}
var payload struct {
Status string `json:"status"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
payload.Status = strings.TrimSpace(payload.Status)
if payload.Status == "" {
response.BadRequest(c, "Invalid status")
return
}
if payload.Status != service.OpsAlertStatusResolved && payload.Status != service.OpsAlertStatusManualResolved {
response.BadRequest(c, "Invalid status")
return
}
var resolvedAt *time.Time
if payload.Status == service.OpsAlertStatusResolved || payload.Status == service.OpsAlertStatusManualResolved {
now := time.Now().UTC()
resolvedAt = &now
}
if err := h.opsService.UpdateAlertEventStatus(c.Request.Context(), id, payload.Status, resolvedAt); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"updated": true})
}
// ListAlertEvents lists recent ops alert events.
// GET /api/v1/admin/ops/alert-events
// CreateAlertSilence creates a scoped silence for ops alerts.
// POST /api/v1/admin/ops/alert-silences
func (h *OpsHandler) CreateAlertSilence(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var payload struct {
RuleID int64 `json:"rule_id"`
Platform string `json:"platform"`
GroupID *int64 `json:"group_id"`
Region *string `json:"region"`
Until string `json:"until"`
Reason string `json:"reason"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
until, err := time.Parse(time.RFC3339, strings.TrimSpace(payload.Until))
if err != nil {
response.BadRequest(c, "Invalid until")
return
}
createdBy := (*int64)(nil)
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
uid := subject.UserID
createdBy = &uid
}
silence := &service.OpsAlertSilence{
RuleID: payload.RuleID,
Platform: strings.TrimSpace(payload.Platform),
GroupID: payload.GroupID,
Region: payload.Region,
Until: until,
Reason: strings.TrimSpace(payload.Reason),
CreatedBy: createdBy,
}
created, err := h.opsService.CreateAlertSilence(c.Request.Context(), silence)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, created)
}
func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
@@ -384,7 +511,7 @@ func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
return
}
limit := 100
limit := 20
if raw := strings.TrimSpace(c.Query("limit")); raw != "" {
n, err := strconv.Atoi(raw)
if err != nil || n <= 0 {
@@ -400,6 +527,49 @@ func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
Severity: strings.TrimSpace(c.Query("severity")),
}
if v := strings.TrimSpace(c.Query("email_sent")); v != "" {
vv := strings.ToLower(v)
switch vv {
case "true", "1":
b := true
filter.EmailSent = &b
case "false", "0":
b := false
filter.EmailSent = &b
default:
response.BadRequest(c, "Invalid email_sent")
return
}
}
// Cursor pagination: both params must be provided together.
rawTS := strings.TrimSpace(c.Query("before_fired_at"))
rawID := strings.TrimSpace(c.Query("before_id"))
if (rawTS == "") != (rawID == "") {
response.BadRequest(c, "before_fired_at and before_id must be provided together")
return
}
if rawTS != "" {
ts, err := time.Parse(time.RFC3339Nano, rawTS)
if err != nil {
if t2, err2 := time.Parse(time.RFC3339, rawTS); err2 == nil {
ts = t2
} else {
response.BadRequest(c, "Invalid before_fired_at")
return
}
}
filter.BeforeFiredAt = &ts
}
if rawID != "" {
id, err := strconv.ParseInt(rawID, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid before_id")
return
}
filter.BeforeID = &id
}
// Optional global filter support (platform/group/time range).
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform

View File

@@ -19,6 +19,57 @@ type OpsHandler struct {
opsService *service.OpsService
}
// GetErrorLogByID returns ops error log detail.
// GET /api/v1/admin/ops/errors/:id
func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, detail)
}
const (
opsListViewErrors = "errors"
opsListViewExcluded = "excluded"
opsListViewAll = "all"
)
func parseOpsViewParam(c *gin.Context) string {
if c == nil {
return ""
}
v := strings.ToLower(strings.TrimSpace(c.Query("view")))
switch v {
case "", opsListViewErrors:
return opsListViewErrors
case opsListViewExcluded:
return opsListViewExcluded
case opsListViewAll:
return opsListViewAll
default:
return opsListViewErrors
}
}
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
return &OpsHandler{opsService: opsService}
}
@@ -47,16 +98,26 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
return
}
filter := &service.OpsErrorLogFilter{
Page: page,
PageSize: pageSize,
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = parseOpsViewParam(c)
filter.Phase = strings.TrimSpace(c.Query("phase"))
filter.Owner = strings.TrimSpace(c.Query("error_owner"))
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
// Force request errors: client-visible status >= 400.
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
filter.Phase = ""
}
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
@@ -77,11 +138,19 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
}
filter.AccountID = &id
}
if phase := strings.TrimSpace(c.Query("phase")); phase != "" {
filter.Phase = phase
}
if q := strings.TrimSpace(c.Query("q")); q != "" {
filter.Query = q
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes":
b := true
filter.Resolved = &b
case "0", "false", "no":
b := false
filter.Resolved = &b
default:
response.BadRequest(c, "Invalid resolved")
return
}
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
@@ -106,13 +175,120 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// GetErrorLogByID returns a single error log detail.
// GET /api/v1/admin/ops/errors/:id
func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
// ListRequestErrors lists client-visible request errors.
// GET /api/v1/admin/ops/request-errors
func (h *OpsHandler) ListRequestErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 500 {
pageSize = 500
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = parseOpsViewParam(c)
filter.Phase = strings.TrimSpace(c.Query("phase"))
filter.Owner = strings.TrimSpace(c.Query("error_owner"))
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
// Force request errors: client-visible status >= 400.
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
filter.Phase = ""
}
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes":
b := true
filter.Resolved = &b
case "0", "false", "no":
b := false
filter.Resolved = &b
default:
response.BadRequest(c, "Invalid resolved")
return
}
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
out := make([]int, 0, len(parts))
for _, part := range parts {
p := strings.TrimSpace(part)
if p == "" {
continue
}
n, err := strconv.Atoi(p)
if err != nil || n < 0 {
response.BadRequest(c, "Invalid status_codes")
return
}
out = append(out, n)
}
filter.StatusCodes = out
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// GetRequestError returns request error detail.
// GET /api/v1/admin/ops/request-errors/:id
func (h *OpsHandler) GetRequestError(c *gin.Context) {
// same storage; just proxy to existing detail
h.GetErrorLogByID(c)
}
// ListRequestErrorUpstreamErrors lists upstream error logs correlated to a request error.
// GET /api/v1/admin/ops/request-errors/:id/upstream-errors
func (h *OpsHandler) ListRequestErrorUpstreamErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
@@ -129,15 +305,306 @@ func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
return
}
// Load request error to get correlation keys.
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, detail)
// Correlate by request_id/client_request_id.
requestID := strings.TrimSpace(detail.RequestID)
clientRequestID := strings.TrimSpace(detail.ClientRequestID)
if requestID == "" && clientRequestID == "" {
response.Paginated(c, []*service.OpsErrorLog{}, 0, 1, 10)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 500 {
pageSize = 500
}
// Keep correlation window wide enough so linked upstream errors
// are discoverable even when UI defaults to 1h elsewhere.
startTime, endTime, err := parseOpsTimeRange(c, "30d")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = "all"
filter.Phase = "upstream"
filter.Owner = "provider"
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
// Prefer exact match on request_id; if missing, fall back to client_request_id.
if requestID != "" {
filter.RequestID = requestID
} else {
filter.ClientRequestID = clientRequestID
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
// If client asks for details, expand each upstream error log to include upstream response fields.
includeDetail := strings.TrimSpace(c.Query("include_detail"))
if includeDetail == "1" || strings.EqualFold(includeDetail, "true") || strings.EqualFold(includeDetail, "yes") {
details := make([]*service.OpsErrorLogDetail, 0, len(result.Errors))
for _, item := range result.Errors {
if item == nil {
continue
}
d, err := h.opsService.GetErrorLogByID(c.Request.Context(), item.ID)
if err != nil || d == nil {
continue
}
details = append(details, d)
}
response.Paginated(c, details, int64(result.Total), result.Page, result.PageSize)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// RetryRequestErrorClient retries the client request based on stored request body.
// POST /api/v1/admin/ops/request-errors/:id/retry-client
func (h *OpsHandler) RetryRequestErrorClient(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeClient, nil)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// RetryRequestErrorUpstreamEvent retries a specific upstream attempt using captured upstream_request_body.
// POST /api/v1/admin/ops/request-errors/:id/upstream-errors/:idx/retry
func (h *OpsHandler) RetryRequestErrorUpstreamEvent(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
idxStr := strings.TrimSpace(c.Param("idx"))
idx, err := strconv.Atoi(idxStr)
if err != nil || idx < 0 {
response.BadRequest(c, "Invalid upstream idx")
return
}
result, err := h.opsService.RetryUpstreamEvent(c.Request.Context(), subject.UserID, id, idx)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ResolveRequestError toggles resolved status.
// PUT /api/v1/admin/ops/request-errors/:id/resolve
func (h *OpsHandler) ResolveRequestError(c *gin.Context) {
h.UpdateErrorResolution(c)
}
// ListUpstreamErrors lists independent upstream errors.
// GET /api/v1/admin/ops/upstream-errors
func (h *OpsHandler) ListUpstreamErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 500 {
pageSize = 500
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = parseOpsViewParam(c)
filter.Phase = "upstream"
filter.Owner = "provider"
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes":
b := true
filter.Resolved = &b
case "0", "false", "no":
b := false
filter.Resolved = &b
default:
response.BadRequest(c, "Invalid resolved")
return
}
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
out := make([]int, 0, len(parts))
for _, part := range parts {
p := strings.TrimSpace(part)
if p == "" {
continue
}
n, err := strconv.Atoi(p)
if err != nil || n < 0 {
response.BadRequest(c, "Invalid status_codes")
return
}
out = append(out, n)
}
filter.StatusCodes = out
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// GetUpstreamError returns upstream error detail.
// GET /api/v1/admin/ops/upstream-errors/:id
func (h *OpsHandler) GetUpstreamError(c *gin.Context) {
h.GetErrorLogByID(c)
}
// RetryUpstreamError retries upstream error using the original account_id.
// POST /api/v1/admin/ops/upstream-errors/:id/retry
func (h *OpsHandler) RetryUpstreamError(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeUpstream, nil)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ResolveUpstreamError toggles resolved status.
// PUT /api/v1/admin/ops/upstream-errors/:id/resolve
func (h *OpsHandler) ResolveUpstreamError(c *gin.Context) {
h.UpdateErrorResolution(c)
}
// ==================== Existing endpoints ====================
// ListRequestDetails returns a request-level list (success + error) for drill-down.
// GET /api/v1/admin/ops/requests
func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
@@ -242,6 +709,11 @@ func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
type opsRetryRequest struct {
Mode string `json:"mode"`
PinnedAccountID *int64 `json:"pinned_account_id"`
Force bool `json:"force"`
}
type opsResolveRequest struct {
Resolved bool `json:"resolved"`
}
// RetryErrorRequest retries a failed request using stored request_body.
@@ -278,6 +750,16 @@ func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
req.Mode = service.OpsRetryModeClient
}
// Force flag is currently a UI-level acknowledgement. Server may still enforce safety constraints.
_ = req.Force
// Legacy endpoint safety: only allow retrying the client request here.
// Upstream retries must go through the split endpoints.
if strings.EqualFold(strings.TrimSpace(req.Mode), service.OpsRetryModeUpstream) {
response.BadRequest(c, "upstream retry is not supported on this endpoint")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
if err != nil {
response.ErrorFrom(c, err)
@@ -287,6 +769,81 @@ func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
response.Success(c, result)
}
// ListRetryAttempts lists retry attempts for an error log.
// GET /api/v1/admin/ops/errors/:id/retries
func (h *OpsHandler) ListRetryAttempts(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
limit := 50
if v := strings.TrimSpace(c.Query("limit")); v != "" {
n, err := strconv.Atoi(v)
if err != nil || n <= 0 {
response.BadRequest(c, "Invalid limit")
return
}
limit = n
}
items, err := h.opsService.ListRetryAttemptsByErrorID(c.Request.Context(), id, limit)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, items)
}
// UpdateErrorResolution allows manual resolve/unresolve.
// PUT /api/v1/admin/ops/errors/:id/resolve
func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
var req opsResolveRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
uid := subject.UserID
if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid, nil); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"ok": true})
}
func parseOpsTimeRange(c *gin.Context, defaultRange string) (time.Time, time.Time, error) {
startStr := strings.TrimSpace(c.Query("start_time"))
endStr := strings.TrimSpace(c.Query("end_time"))
@@ -358,6 +915,10 @@ func parseOpsDuration(v string) (time.Duration, bool) {
return 6 * time.Hour, true
case "24h":
return 24 * time.Hour, true
case "7d":
return 7 * 24 * time.Hour, true
case "30d":
return 30 * 24 * time.Hour, true
default:
return 0, false
}

View File

@@ -196,6 +196,28 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
response.Success(c, gin.H{"message": "Proxy deleted successfully"})
}
// BatchDelete handles batch deleting proxies
// POST /api/v1/admin/proxies/batch-delete
func (h *ProxyHandler) BatchDelete(c *gin.Context) {
type BatchDeleteRequest struct {
IDs []int64 `json:"ids" binding:"required,min=1"`
}
var req BatchDeleteRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.adminService.BatchDeleteProxies(c.Request.Context(), req.IDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// Test handles testing proxy connectivity
// POST /api/v1/admin/proxies/:id/test
func (h *ProxyHandler) Test(c *gin.Context) {
@@ -243,19 +265,17 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
return
}
page, pageSize := response.ParsePagination(c)
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
accounts, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.Account, 0, len(accounts))
out := make([]dto.ProxyAccountSummary, 0, len(accounts))
for i := range accounts {
out = append(out, *dto.AccountFromService(&accounts[i]))
out = append(out, *dto.ProxyAccountSummaryFromService(&accounts[i]))
}
response.Paginated(c, out, total, page, pageSize)
response.Success(c, out)
}
// BatchCreateProxyItem represents a single proxy in batch create request

View File

@@ -54,9 +54,9 @@ func (h *RedeemHandler) List(c *gin.Context) {
return
}
out := make([]dto.RedeemCode, 0, len(codes))
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -76,7 +76,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.RedeemCodeFromService(code))
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
}
// Generate handles generating new redeem codes
@@ -100,9 +100,9 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return
}
out := make([]dto.RedeemCode, 0, len(codes))
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
response.Success(c, out)
}
@@ -163,7 +163,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
return
}
response.Success(c, dto.RedeemCodeFromService(code))
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
}
// GetStats handles getting redeem code statistics

View File

@@ -47,6 +47,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
@@ -68,6 +72,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
@@ -87,8 +92,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
@@ -111,13 +119,14 @@ type UpdateSettingsRequest struct {
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
@@ -194,6 +203,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// TOTP 双因素认证参数验证
// 只有手动配置了加密密钥才允许启用 TOTP 功能
if req.TotpEnabled && !previousSettings.TotpEnabled {
// 尝试启用 TOTP检查加密密钥是否已手动配置
if !h.settingService.IsTotpEncryptionKeyConfigured() {
response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
return
}
}
// LinuxDo Connect 参数验证
if req.LinuxDoConnectEnabled {
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
@@ -238,6 +257,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
@@ -259,6 +281,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
@@ -311,6 +334,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
@@ -332,6 +359,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
HomeContent: updatedSettings.HomeContent,
HideCcsImportButton: updatedSettings.HideCcsImportButton,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
@@ -376,6 +404,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
if before.TotpEnabled != after.TotpEnabled {
changed = append(changed, "totp_enabled")
}
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}
@@ -439,6 +473,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.HomeContent != after.HomeContent {
changed = append(changed, "home_content")
}
if before.HideCcsImportButton != after.HideCcsImportButton {
changed = append(changed, "hide_ccs_import_button")
}
if before.DefaultConcurrency != after.DefaultConcurrency {
changed = append(changed, "default_concurrency")
}

View File

@@ -53,9 +53,9 @@ type BulkAssignSubscriptionRequest struct {
Notes string `json:"notes"`
}
// ExtendSubscriptionRequest represents extend subscription request
type ExtendSubscriptionRequest struct {
Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years
// AdjustSubscriptionRequest represents adjust subscription request (extend or shorten)
type AdjustSubscriptionRequest struct {
Days int `json:"days" binding:"required,min=-36500,max=36500"` // negative to shorten, positive to extend
}
// List handles listing all subscriptions with pagination and filters
@@ -77,15 +77,19 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
}
status := c.Query("status")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
// Parse sorting parameters
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.UserSubscription, 0, len(subscriptions))
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
@@ -105,7 +109,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.UserSubscriptionFromService(subscription))
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// GetProgress handles getting subscription usage progress
@@ -150,7 +154,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
return
}
response.Success(c, dto.UserSubscriptionFromService(subscription))
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// BulkAssign handles bulk assigning subscriptions to multiple users
@@ -180,7 +184,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
response.Success(c, dto.BulkAssignResultFromService(result))
}
// Extend handles extending a subscription
// Extend handles adjusting a subscription (extend or shorten)
// POST /api/v1/admin/subscriptions/:id/extend
func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
@@ -189,7 +193,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return
}
var req ExtendSubscriptionRequest
var req AdjustSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
@@ -201,7 +205,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return
}
response.Success(c, dto.UserSubscriptionFromService(subscription))
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// Revoke handles revoking a subscription
@@ -239,9 +243,9 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
return
}
out := make([]dto.UserSubscription, 0, len(subscriptions))
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
@@ -261,9 +265,9 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
return
}
out := make([]dto.UserSubscription, 0, len(subscriptions))
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.Success(c, out)
}

View File

@@ -0,0 +1,377 @@
package admin
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type cleanupRepoStub struct {
mu sync.Mutex
created []*service.UsageCleanupTask
listTasks []service.UsageCleanupTask
listResult *pagination.PaginationResult
listErr error
statusByID map[int64]string
}
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
if task == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
if task.ID == 0 {
task.ID = int64(len(s.created) + 1)
}
if task.CreatedAt.IsZero() {
task.CreatedAt = time.Now().UTC()
}
task.UpdatedAt = task.CreatedAt
clone := *task
s.created = append(s.created, &clone)
return nil
}
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.listTasks, s.listResult, s.listErr
}
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
return nil, nil
}
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusByID == nil {
return "", sql.ErrNoRows
}
status, ok := s.statusByID[taskID]
if !ok {
return "", sql.ErrNoRows
}
return status, nil
}
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
return nil
}
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
status := s.statusByID[taskID]
if status != service.UsageCleanupStatusPending && status != service.UsageCleanupStatusRunning {
return false, nil
}
s.statusByID[taskID] = service.UsageCleanupStatusCanceled
return true, nil
}
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
return nil
}
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
return nil
}
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
return 0, nil
}
var _ service.UsageCleanupRepository = (*cleanupRepoStub)(nil)
func setupCleanupRouter(cleanupService *service.UsageCleanupService, userID int64) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
if userID > 0 {
router.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
c.Next()
})
}
handler := NewUsageHandler(nil, nil, nil, cleanupService)
router.POST("/api/v1/admin/usage/cleanup-tasks", handler.CreateCleanupTask)
router.GET("/api/v1/admin/usage/cleanup-tasks", handler.ListCleanupTasks)
router.POST("/api/v1/admin/usage/cleanup-tasks/:id/cancel", handler.CancelCleanupTask)
return router
}
func TestUsageHandlerCreateCleanupTaskUnauthorized(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 0)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusUnauthorized, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskUnavailable(t *testing.T) {
router := setupCleanupRouter(nil, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskBindError(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString("{bad-json"))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskMissingRange(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidDate(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-13-01",
"end_date": "2024-01-02",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-02-40",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": " 2024-01-01 ",
"end_date": "2024-01-02",
"timezone": "UTC",
"model": "gpt-4",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.Equal(t, int64(99), created.CreatedBy)
require.NotNil(t, created.Filters.Model)
require.Equal(t, "gpt-4", *created.Filters.Model)
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC).Add(24*time.Hour - time.Nanosecond)
require.True(t, created.Filters.StartTime.Equal(start))
require.True(t, created.Filters.EndTime.Equal(end))
}
func TestUsageHandlerListCleanupTasksUnavailable(t *testing.T) {
router := setupCleanupRouter(nil, 0)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}
func TestUsageHandlerListCleanupTasksSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
repo.listTasks = []service.UsageCleanupTask{
{
ID: 7,
Status: service.UsageCleanupStatusSucceeded,
CreatedBy: 4,
},
}
repo.listResult = &pagination.PaginationResult{Total: 1, Page: 1, PageSize: 20, Pages: 1}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Items []dto.UsageCleanupTask `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Items, 1)
require.Equal(t, int64(7), resp.Data.Items[0].ID)
require.Equal(t, int64(1), resp.Data.Total)
require.Equal(t, 1, resp.Data.Page)
}
func TestUsageHandlerListCleanupTasksError(t *testing.T) {
repo := &cleanupRepoStub{listErr: errors.New("boom")}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusInternalServerError, recorder.Code)
}
func TestUsageHandlerCancelCleanupTaskUnauthorized(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 0)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/1/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskNotFound(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/999/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskConflict(t *testing.T) {
repo := &cleanupRepoStub{statusByID: map[int64]string{2: service.UsageCleanupStatusSucceeded}}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/2/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{statusByID: map[int64]string{3: service.UsageCleanupStatusPending}}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/3/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}

View File

@@ -1,7 +1,10 @@
package admin
import (
"log"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -9,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -16,9 +20,10 @@ import (
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageService *service.UsageService
apiKeyService *service.APIKeyService
adminService service.AdminService
usageService *service.UsageService
apiKeyService *service.APIKeyService
adminService service.AdminService
cleanupService *service.UsageCleanupService
}
// NewUsageHandler creates a new admin usage handler
@@ -26,14 +31,30 @@ func NewUsageHandler(
usageService *service.UsageService,
apiKeyService *service.APIKeyService,
adminService service.AdminService,
cleanupService *service.UsageCleanupService,
) *UsageHandler {
return &UsageHandler{
usageService: usageService,
apiKeyService: apiKeyService,
adminService: adminService,
usageService: usageService,
apiKeyService: apiKeyService,
adminService: adminService,
cleanupService: cleanupService,
}
}
// CreateUsageCleanupTaskRequest represents cleanup task creation request
type CreateUsageCleanupTaskRequest struct {
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
UserID *int64 `json:"user_id"`
APIKeyID *int64 `json:"api_key_id"`
AccountID *int64 `json:"account_id"`
GroupID *int64 `json:"group_id"`
Model *string `json:"model"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
Timezone string `json:"timezone"`
}
// List handles listing all usage records with filters
// GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) {
@@ -142,7 +163,7 @@ func (h *UsageHandler) List(c *gin.Context) {
return
}
out := make([]dto.UsageLog, 0, len(records))
out := make([]dto.AdminUsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
}
@@ -344,3 +365,162 @@ func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
response.Success(c, result)
}
// ListCleanupTasks handles listing usage cleanup tasks
// GET /api/v1/admin/usage/cleanup-tasks
func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
operator := int64(0)
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
operator = subject.UserID
}
page, pageSize := response.ParsePagination(c)
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
if err != nil {
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
response.ErrorFrom(c, err)
return
}
out := make([]dto.UsageCleanupTask, 0, len(tasks))
for i := range tasks {
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
}
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
response.Paginated(c, out, result.Total, page, pageSize)
}
// CreateCleanupTask handles creating a usage cleanup task
// POST /api/v1/admin/usage/cleanup-tasks
func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Unauthorized(c, "Unauthorized")
return
}
var req CreateUsageCleanupTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.StartDate = strings.TrimSpace(req.StartDate)
req.EndDate = strings.TrimSpace(req.EndDate)
if req.StartDate == "" || req.EndDate == "" {
response.BadRequest(c, "start_date and end_date are required")
return
}
startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
filters := service.UsageCleanupFilters{
StartTime: startTime,
EndTime: endTime,
UserID: req.UserID,
APIKeyID: req.APIKeyID,
AccountID: req.AccountID,
GroupID: req.GroupID,
Model: req.Model,
Stream: req.Stream,
BillingType: req.BillingType,
}
var userID any
if filters.UserID != nil {
userID = *filters.UserID
}
var apiKeyID any
if filters.APIKeyID != nil {
apiKeyID = *filters.APIKeyID
}
var accountID any
if filters.AccountID != nil {
accountID = *filters.AccountID
}
var groupID any
if filters.GroupID != nil {
groupID = *filters.GroupID
}
var model any
if filters.Model != nil {
model = *filters.Model
}
var stream any
if filters.Stream != nil {
stream = *filters.Stream
}
var billingType any
if filters.BillingType != nil {
billingType = *filters.BillingType
}
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
userID,
apiKeyID,
accountID,
groupID,
model,
stream,
billingType,
req.Timezone,
)
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
if err != nil {
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
response.Success(c, dto.UsageCleanupTaskFromService(task))
}
// CancelCleanupTask handles canceling a usage cleanup task
// POST /api/v1/admin/usage/cleanup-tasks/:id/cancel
func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Unauthorized(c, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
taskID, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || taskID <= 0 {
response.BadRequest(c, "Invalid task id")
return
}
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
}

View File

@@ -84,9 +84,9 @@ func (h *UserHandler) List(c *gin.Context) {
return
}
out := make([]dto.User, 0, len(users))
out := make([]dto.AdminUser, 0, len(users))
for i := range users {
out = append(out, *dto.UserFromService(&users[i]))
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -129,7 +129,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Create handles creating a new user
@@ -155,7 +155,7 @@ func (h *UserHandler) Create(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Update handles updating a user
@@ -189,7 +189,7 @@ func (h *UserHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Delete handles deleting a user
@@ -231,7 +231,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// GetUserAPIKeys handles getting user's API keys

View File

@@ -1,6 +1,8 @@
package handler
import (
"log/slog"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
@@ -18,16 +20,18 @@ type AuthHandler struct {
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
totpService *service.TotpService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler {
return &AuthHandler{
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
totpService: totpService,
}
}
@@ -144,6 +148,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA
tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
if err != nil {
response.InternalError(c, "Failed to create 2FA session")
return
}
response.Success(c, TotpLoginResponse{
Requires2FA: true,
TempToken: tempToken,
UserEmailMasked: service.MaskEmail(user.Email),
})
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
}
// TotpLoginResponse represents the response when 2FA is required
type TotpLoginResponse struct {
Requires2FA bool `json:"requires_2fa"`
TempToken string `json:"temp_token,omitempty"`
UserEmailMasked string `json:"user_email_masked,omitempty"`
}
// Login2FARequest represents the 2FA login request
type Login2FARequest struct {
TempToken string `json:"temp_token" binding:"required"`
TotpCode string `json:"totp_code" binding:"required,len=6"`
}
// Login2FA completes the login with 2FA verification
// POST /api/v1/auth/login/2fa
func (h *AuthHandler) Login2FA(c *gin.Context) {
var req Login2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
slog.Debug("login_2fa_request",
"temp_token_len", len(req.TempToken),
"totp_code_len", len(req.TotpCode))
// Get the login session
session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
if err != nil || session == nil {
tokenPrefix := ""
if len(req.TempToken) >= 8 {
tokenPrefix = req.TempToken[:8]
}
slog.Debug("login_2fa_session_invalid",
"temp_token_prefix", tokenPrefix,
"error", err)
response.BadRequest(c, "Invalid or expired 2FA session")
return
}
slog.Debug("login_2fa_session_found",
"user_id", session.UserID,
"email", session.Email)
// Verify the TOTP code
if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
slog.Debug("login_2fa_verify_failed",
"user_id", session.UserID,
"error", err)
response.ErrorFrom(c, err)
return
}
// Delete the login session
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
// Get the user
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Generate the JWT token
token, err := h.authService.GenerateToken(user)
if err != nil {
response.InternalError(c, "Failed to generate token")
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
@@ -195,6 +293,15 @@ type ValidatePromoCodeResponse struct {
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
// POST /api/v1/auth/validate-promo-code
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
// 检查优惠码功能是否启用
if h.settingSvc != nil && !h.settingSvc.IsPromoCodeEnabled(c.Request.Context()) {
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: "PROMO_CODE_DISABLED",
})
return
}
var req ValidatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
@@ -238,3 +345,85 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
BonusAmount: promoCode.BonusAmount,
})
}
// ForgotPasswordRequest 忘记密码请求
type ForgotPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
TurnstileToken string `json:"turnstile_token"`
}
// ForgotPasswordResponse 忘记密码响应
type ForgotPasswordResponse struct {
Message string `json:"message"`
}
// ForgotPassword 请求密码重置
// POST /api/v1/auth/forgot-password
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
var req ForgotPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
// Build frontend base URL from request
scheme := "https"
if c.Request.TLS == nil {
// Check X-Forwarded-Proto header (common in reverse proxy setups)
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
scheme = proto
} else {
scheme = "http"
}
}
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ForgotPasswordResponse{
Message: "If your email is registered, you will receive a password reset link shortly.",
})
}
// ResetPasswordRequest 重置密码请求
type ResetPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
Token string `json:"token" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
// ResetPasswordResponse 重置密码响应
type ResetPasswordResponse struct {
Message string `json:"message"`
}
// ResetPassword 重置密码
// POST /api/v1/auth/reset-password
func (h *AuthHandler) ResetPassword(c *gin.Context) {
var req ResetPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Reset password
if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ResetPasswordResponse{
Message: "Your password has been reset successfully. You can now log in with your new password.",
})
}

View File

@@ -15,7 +15,6 @@ func UserFromServiceShallow(u *service.User) *User {
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
@@ -48,6 +47,22 @@ func UserFromService(u *service.User) *User {
return out
}
// UserFromServiceAdmin converts a service User to DTO for admin users.
// It includes notes - user-facing endpoints must not use this.
func UserFromServiceAdmin(u *service.User) *AdminUser {
if u == nil {
return nil
}
base := UserFromService(u)
if base == nil {
return nil
}
return &AdminUser{
User: *base,
Notes: u.Notes,
}
}
func APIKeyFromService(k *service.APIKey) *APIKey {
if k == nil {
return nil
@@ -72,7 +87,41 @@ func GroupFromServiceShallow(g *service.Group) *Group {
if g == nil {
return nil
}
return &Group{
out := groupFromServiceBase(g)
return &out
}
func GroupFromService(g *service.Group) *Group {
if g == nil {
return nil
}
return GroupFromServiceShallow(g)
}
// GroupFromServiceAdmin converts a service Group to DTO for admin users.
// It includes internal fields like model_routing and account_count.
func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
if g == nil {
return nil
}
out := &AdminGroup{
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
AccountCount: g.AccountCount,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
for i := range g.AccountGroups {
ag := g.AccountGroups[i]
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
}
}
return out
}
func groupFromServiceBase(g *service.Group) Group {
return Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
@@ -91,30 +140,14 @@ func GroupFromServiceShallow(g *service.Group) *Group {
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
}
}
func GroupFromService(g *service.Group) *Group {
if g == nil {
return nil
}
out := GroupFromServiceShallow(g)
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
for i := range g.AccountGroups {
ag := g.AccountGroups[i]
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
}
}
return out
}
func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
}
return &Account{
out := &Account{
ID: a.ID,
Name: a.Name,
Notes: a.Notes,
@@ -125,6 +158,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
RateMultiplier: a.BillingRateMultiplier(),
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
@@ -143,6 +177,34 @@ func AccountFromServiceShallow(a *service.Account) *Account {
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
}
// 提取 5h 窗口费用控制和会话数量控制配置(仅 Anthropic OAuth/SetupToken 账号有效)
if a.IsAnthropicOAuthOrSetupToken() {
if limit := a.GetWindowCostLimit(); limit > 0 {
out.WindowCostLimit = &limit
}
if reserve := a.GetWindowCostStickyReserve(); reserve > 0 {
out.WindowCostStickyReserve = &reserve
}
if maxSessions := a.GetMaxSessions(); maxSessions > 0 {
out.MaxSessions = &maxSessions
}
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
out.SessionIdleTimeoutMin = &idleTimeout
}
// TLS指纹伪装开关
if a.IsTLSFingerprintEnabled() {
enabled := true
out.EnableTLSFingerprint = &enabled
}
// 会话ID伪装开关
if a.IsSessionIDMaskingEnabled() {
enabled := true
out.EnableSessionIDMasking = &enabled
}
}
return out
}
func AccountFromService(a *service.Account) *Account {
@@ -212,8 +274,29 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
return nil
}
return &ProxyWithAccountCount{
Proxy: *ProxyFromService(&p.Proxy),
AccountCount: p.AccountCount,
Proxy: *ProxyFromService(&p.Proxy),
AccountCount: p.AccountCount,
LatencyMs: p.LatencyMs,
LatencyStatus: p.LatencyStatus,
LatencyMessage: p.LatencyMessage,
IPAddress: p.IPAddress,
Country: p.Country,
CountryCode: p.CountryCode,
Region: p.Region,
City: p.City,
}
}
func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
if a == nil {
return nil
}
return &ProxyAccountSummary{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Notes: a.Notes,
}
}
@@ -221,7 +304,24 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
if rc == nil {
return nil
}
return &RedeemCode{
out := redeemCodeFromServiceBase(rc)
return &out
}
// RedeemCodeFromServiceAdmin converts a service RedeemCode to DTO for admin users.
// It includes notes - user-facing endpoints must not use this.
func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
if rc == nil {
return nil
}
return &AdminRedeemCode{
RedeemCode: redeemCodeFromServiceBase(rc),
Notes: rc.Notes,
}
}
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
return RedeemCode{
ID: rc.ID,
Code: rc.Code,
Type: rc.Type,
@@ -229,7 +329,6 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
Status: rc.Status,
UsedBy: rc.UsedBy,
UsedAt: rc.UsedAt,
Notes: rc.Notes,
CreatedAt: rc.CreatedAt,
GroupID: rc.GroupID,
ValidityDays: rc.ValidityDays,
@@ -250,14 +349,9 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
}
}
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
// The account parameter allows caller to control what Account info is included.
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
if l == nil {
return nil
}
result := &UsageLog{
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
// 普通用户 DTO严禁包含管理员字段例如 account_rate_multiplier、ip_address、account
return UsageLog{
ID: l.ID,
UserID: l.UserID,
APIKeyID: l.APIKeyID,
@@ -289,30 +383,63 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey),
Account: account,
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),
}
// IP 地址仅对管理员可见
if includeIPAddress {
result.IPAddress = l.IPAddress
}
return result
}
// UsageLogFromService converts a service UsageLog to DTO for regular users.
// It excludes Account details and IP address - users should not see these.
func UsageLogFromService(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, nil, false)
if l == nil {
return nil
}
u := usageLogFromServiceUser(l)
return &u
}
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
// It includes minimal Account info (ID, Name only) and IP address.
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
if l == nil {
return nil
}
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
return &AdminUsageLog{
UsageLog: usageLogFromServiceUser(l),
AccountRateMultiplier: l.AccountRateMultiplier,
IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account),
}
}
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
if task == nil {
return nil
}
return &UsageCleanupTask{
ID: task.ID,
Status: task.Status,
Filters: UsageCleanupFilters{
StartTime: task.Filters.StartTime,
EndTime: task.Filters.EndTime,
UserID: task.Filters.UserID,
APIKeyID: task.Filters.APIKeyID,
AccountID: task.Filters.AccountID,
GroupID: task.Filters.GroupID,
Model: task.Filters.Model,
Stream: task.Filters.Stream,
BillingType: task.Filters.BillingType,
},
CreatedBy: task.CreatedBy,
DeletedRows: task.DeletedRows,
ErrorMessage: task.ErrorMsg,
CanceledBy: task.CanceledBy,
CanceledAt: task.CanceledAt,
StartedAt: task.StartedAt,
FinishedAt: task.FinishedAt,
CreatedAt: task.CreatedAt,
UpdatedAt: task.UpdatedAt,
}
}
func SettingFromService(s *service.Setting) *Setting {
@@ -331,7 +458,27 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
if sub == nil {
return nil
}
return &UserSubscription{
out := userSubscriptionFromServiceBase(sub)
return &out
}
// UserSubscriptionFromServiceAdmin converts a service UserSubscription to DTO for admin users.
// It includes assignment metadata and notes.
func UserSubscriptionFromServiceAdmin(sub *service.UserSubscription) *AdminUserSubscription {
if sub == nil {
return nil
}
return &AdminUserSubscription{
UserSubscription: userSubscriptionFromServiceBase(sub),
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}
func userSubscriptionFromServiceBase(sub *service.UserSubscription) UserSubscription {
return UserSubscription{
ID: sub.ID,
UserID: sub.UserID,
GroupID: sub.GroupID,
@@ -344,14 +491,10 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
DailyUsageUSD: sub.DailyUsageUSD,
WeeklyUsageUSD: sub.WeeklyUsageUSD,
MonthlyUsageUSD: sub.MonthlyUsageUSD,
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
CreatedAt: sub.CreatedAt,
UpdatedAt: sub.UpdatedAt,
User: UserFromServiceShallow(sub.User),
Group: GroupFromServiceShallow(sub.Group),
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}
@@ -359,9 +502,9 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
if r == nil {
return nil
}
subs := make([]UserSubscription, 0, len(r.Subscriptions))
subs := make([]AdminUserSubscription, 0, len(r.Subscriptions))
for i := range r.Subscriptions {
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
}
return &BulkAssignResult{
SuccessCount: r.SuccessCount,

View File

@@ -2,8 +2,12 @@ package dto
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@@ -22,13 +26,14 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
@@ -52,19 +57,23 @@ type SystemSettings struct {
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
}
// StreamTimeoutSettings 流超时处理配置 DTO

View File

@@ -6,7 +6,6 @@ type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Notes string `json:"notes"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
@@ -19,6 +18,14 @@ type User struct {
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
// AdminUser 是管理员接口使用的 user DTO包含敏感/内部字段)。
// 注意:普通用户接口不得返回 notes 等管理员备注信息。
type AdminUser struct {
User
Notes string `json:"notes"`
}
type APIKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
@@ -60,6 +67,16 @@ type Group struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AdminGroup 是管理员接口使用的 group DTO包含敏感/内部字段)。
// 注意:普通用户接口不得返回 model_routing/account_count/account_groups 等内部信息。
type AdminGroup struct {
Group
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
@@ -76,6 +93,7 @@ type Account struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
RateMultiplier float64 `json:"rate_multiplier"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
@@ -97,6 +115,25 @@ type Account struct {
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `json:"session_window_status"`
// 5h窗口费用控制仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
WindowCostLimit *float64 `json:"window_cost_limit,omitempty"`
WindowCostStickyReserve *float64 `json:"window_cost_sticky_reserve,omitempty"`
// 会话数量控制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
MaxSessions *int `json:"max_sessions,omitempty"`
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
// TLS指纹伪装仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
// 会话ID伪装仅 Anthropic OAuth/SetupToken 账号有效)
// 启用后将在15分钟内固定 metadata.user_id 中的 session ID
// 从 extra 字段提取,方便前端显示和编辑
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -129,7 +166,23 @@ type Proxy struct {
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
AccountCount int64 `json:"account_count"`
LatencyMs *int64 `json:"latency_ms,omitempty"`
LatencyStatus string `json:"latency_status,omitempty"`
LatencyMessage string `json:"latency_message,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
Country string `json:"country,omitempty"`
CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
}
type ProxyAccountSummary struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Notes *string `json:"notes,omitempty"`
}
type RedeemCode struct {
@@ -140,7 +193,6 @@ type RedeemCode struct {
Status string `json:"status"`
UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
GroupID *int64 `json:"group_id"`
@@ -150,6 +202,15 @@ type RedeemCode struct {
Group *Group `json:"group,omitempty"`
}
// AdminRedeemCode 是管理员接口使用的 redeem code DTO包含 notes 等字段)。
// 注意:普通用户接口不得返回 notes 等内部信息。
type AdminRedeemCode struct {
RedeemCode
Notes string `json:"notes"`
}
// UsageLog 是普通用户接口使用的 usage log DTO不包含管理员字段
type UsageLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
@@ -189,18 +250,55 @@ type UsageLog struct {
// User-Agent
UserAgent *string `json:"user_agent"`
// IP 地址(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
APIKey *APIKey `json:"api_key,omitempty"`
Account *AccountSummary `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`
}
// AdminUsageLog 是管理员接口使用的 usage log DTO包含管理员字段
type AdminUsageLog struct {
UsageLog
// AccountRateMultiplier 账号计费倍率快照nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
// IPAddress 用户请求 IP仅管理员可见
IPAddress *string `json:"ip_address,omitempty"`
// Account 最小账号信息(避免泄露敏感字段)
Account *AccountSummary `json:"account,omitempty"`
}
type UsageCleanupFilters struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
UserID *int64 `json:"user_id,omitempty"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
AccountID *int64 `json:"account_id,omitempty"`
GroupID *int64 `json:"group_id,omitempty"`
Model *string `json:"model,omitempty"`
Stream *bool `json:"stream,omitempty"`
BillingType *int8 `json:"billing_type,omitempty"`
}
type UsageCleanupTask struct {
ID int64 `json:"id"`
Status string `json:"status"`
Filters UsageCleanupFilters `json:"filters"`
CreatedBy int64 `json:"created_by"`
DeletedRows int64 `json:"deleted_rows"`
ErrorMessage *string `json:"error_message,omitempty"`
CanceledBy *int64 `json:"canceled_by,omitempty"`
CanceledAt *time.Time `json:"canceled_at,omitempty"`
StartedAt *time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AccountSummary is a minimal account info for usage log display.
// It intentionally excludes sensitive fields like Credentials, Proxy, etc.
type AccountSummary struct {
@@ -232,23 +330,30 @@ type UserSubscription struct {
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
// AdminUserSubscription 是管理员接口使用的订阅 DTO包含分配信息/备注等字段)。
// 注意:普通用户接口不得返回 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段。
type AdminUserSubscription struct {
UserSubscription
AssignedBy *int64 `json:"assigned_by"`
AssignedAt time.Time `json:"assigned_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
}
type BulkAssignResult struct {
SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"`
Subscriptions []UserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"`
Subscriptions []AdminUserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
}
// PromoCode 注册优惠码

View File

@@ -31,6 +31,8 @@ type GatewayHandler struct {
userService *service.UserService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
}
// NewGatewayHandler creates a new GatewayHandler
@@ -44,8 +46,16 @@ func NewGatewayHandler(
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 10
maxAccountSwitchesGemini := 3
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
if cfg.Gateway.MaxAccountSwitchesGemini > 0 {
maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
}
}
return &GatewayHandler{
gatewayService: gatewayService,
@@ -54,6 +64,8 @@ func NewGatewayHandler(
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
}
}
@@ -179,13 +191,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
if platform == service.PlatformGemini {
const maxAccountSwitches = 3
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -197,17 +209,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
// 检查请求拦截(预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
@@ -275,12 +290,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
@@ -314,14 +328,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
const maxAccountSwitches = 10
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -333,17 +347,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
// 检查请求拦截(预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
@@ -409,12 +426,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
@@ -755,17 +771,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
}
// isWarmupRequest 检测是否为预热请求标题生成、Warmup等
func isWarmupRequest(body []byte) bool {
// 快速检查如果body不包含关键字直接返回false
// InterceptType 表示请求拦截类型
type InterceptType int
const (
InterceptTypeNone InterceptType = iota
InterceptTypeWarmup // 预热请求(返回 "New Conversation"
InterceptTypeSuggestionMode // SUGGESTION MODE返回空字符串
)
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
func detectInterceptType(body []byte) InterceptType {
// 快速检查:如果不包含任何关键字,直接返回
bodyStr := string(body)
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
return false
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
if !hasSuggestionMode && !hasWarmupKeyword {
return InterceptTypeNone
}
// 解析完整请求
// 解析请求(只解析一次)
var req struct {
Messages []struct {
Role string `json:"role"`
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
@@ -776,43 +805,71 @@ func isWarmupRequest(body []byte) bool {
} `json:"system"`
}
if err := json.Unmarshal(body, &req); err != nil {
return false
return InterceptTypeNone
}
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return true
// 检查 SUGGESTION MODE最后一条 user 消息)
if hasSuggestionMode && len(req.Messages) > 0 {
lastMsg := req.Messages[len(req.Messages)-1]
if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
lastMsg.Content[0].Type == "text" &&
strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
return InterceptTypeSuggestionMode
}
}
// 检查 Warmup 请求
if hasWarmupKeyword {
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return InterceptTypeWarmup
}
}
}
}
// 检查 system 中的标题提取模式
for _, sys := range req.System {
if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return InterceptTypeWarmup
}
}
}
// 检查 system 中的标题提取模式
for _, system := range req.System {
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return true
}
}
return false
return InterceptTypeNone
}
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
func sendMockWarmupStream(c *gin.Context, model string) {
// sendMockInterceptStream 发送流式 mock 响应(用于请求拦截)
func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
// 根据拦截类型决定响应内容
var msgID string
var outputTokens int
var textDeltas []string
switch interceptType {
case InterceptTypeSuggestionMode:
msgID = "msg_mock_suggestion"
outputTokens = 1
textDeltas = []string{""} // 空内容
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
outputTokens = 2
textDeltas = []string{"New", " Conversation"}
}
// Build message_start event with proper JSON marshaling
messageStart := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": "msg_mock_warmup",
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
@@ -827,16 +884,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
}
messageStartJSON, _ := json.Marshal(messageStart)
// Build events
events := []string{
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
}
// Add text deltas
for _, text := range textDeltas {
delta := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]string{
"type": "text_delta",
"text": text,
},
}
deltaJSON, _ := json.Marshal(delta)
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
}
// Add final events
messageDelta := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": "end_turn",
"stop_sequence": nil,
},
"usage": map[string]int{
"input_tokens": 10,
"output_tokens": outputTokens,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
events = append(events,
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
`event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
)
for _, event := range events {
_, _ = c.Writer.WriteString(event + "\n\n")
c.Writer.Flush()
@@ -844,18 +931,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
}
}
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) {
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
var msgID, text string
var outputTokens int
switch interceptType {
case InterceptTypeSuggestionMode:
msgID = "msg_mock_suggestion"
text = ""
outputTokens = 1
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
text = "New Conversation"
outputTokens = 2
}
c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup",
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
"content": []gin.H{{"type": "text", "text": text}},
"stop_reason": "end_turn",
"usage": gin.H{
"input_tokens": 10,
"output_tokens": 2,
"output_tokens": outputTokens,
},
})
}

View File

@@ -0,0 +1,122 @@
//go:build unit
package handler
import (
"crypto/sha256"
"encoding/hex"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractGeminiCLISessionHash(t *testing.T) {
tests := []struct {
name string
body string
privilegedUserID string
wantEmpty bool
wantHash string
}{
{
name: "with privileged-user-id and tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: false,
wantHash: func() string {
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}(),
},
{
name: "without privileged-user-id but with tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "",
wantEmpty: false,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "without tmp dir",
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
{
name: "empty body",
body: "",
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 创建测试上下文
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/test", nil)
if tt.privilegedUserID != "" {
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
}
// 调用函数
result := extractGeminiCLISessionHash(c, []byte(tt.body))
// 验证结果
if tt.wantEmpty {
require.Empty(t, result, "expected empty session hash")
} else {
require.NotEmpty(t, result, "expected non-empty session hash")
require.Equal(t, tt.wantHash, result, "session hash mismatch")
}
})
}
}
func TestGeminiCLITmpDirRegex(t *testing.T) {
tests := []struct {
name string
input string
wantMatch bool
wantHash string
}{
{
name: "valid tmp dir path",
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "valid tmp dir path in text",
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "invalid hash length",
input: "/Users/ianshaw/.gemini/tmp/abc123",
wantMatch: false,
},
{
name: "no tmp dir",
input: "Hello world",
wantMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
if tt.wantMatch {
require.NotNil(t, match, "expected regex to match")
require.Len(t, match, 2, "expected 2 capture groups")
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
} else {
require.Nil(t, match, "expected regex not to match")
}
})
}
}

View File

@@ -1,11 +1,15 @@
package handler
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"io"
"log"
"net/http"
"regexp"
"strings"
"time"
@@ -19,6 +23,17 @@ import (
"github.com/gin-gonic/gin"
)
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
return true
}
return geminiCLITmpDirRegex.Match(body)
}
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
@@ -214,19 +229,33 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 优先使用 Gemini CLI 的会话标识privileged-user-id + tmp 目录哈希)
sessionHash := extractGeminiCLISessionHash(c, body)
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
const maxAccountSwitches = 3
// 查询粘性会话绑定的账号 ID用于检测账号切换
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后CLI 继续携带旧签名。
// 为避免第一次转发就 400这里做一次确定性清理让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
} else if sessionBoundAccountID == 0 {
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
sessionBoundAccountID = account.ID
}
// 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
}
return false
}
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
//
// 会话标识生成策略:
// 1. 从请求体中提取 tmp 目录哈希64位十六进制
// 2. 从 header 中提取 privileged-user-idUUID
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
//
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
//
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
// 1. 从请求体中提取 tmp 目录哈希
match := geminiCLITmpDirRegex.FindSubmatch(body)
if len(match) < 2 {
return "" // 没有找到 tmp 目录,不使用粘性会话
}
tmpDirHash := string(match[1])
// 2. 提取 privileged-user-id
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
// 3. 组合生成最终的 session hash
if privilegedUserID != "" {
// 组合两个标识符privileged-user-id + tmp 目录哈希
combined := privilegedUserID + ":" + tmpDirHash
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}
// 如果没有 privileged-user-id直接使用 tmp 目录哈希
return tmpDirHash
}

View File

@@ -37,6 +37,7 @@ type Handlers struct {
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler
Totp *TotpHandler
}
// BuildInfo contains build-time information

View File

@@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler(
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 3
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
}
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
}
}
@@ -114,6 +120,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
// 要求 previous_response_id或 input 内存在带 call_id 的 tool_call/function_call
// 或带 id 且与 call_id 匹配的 item_reference。
if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
if service.HasFunctionCallOutputMissingCallID(reqBody) {
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
}
}
}
// Track if we've started streaming (for error handling)
streamStarted := false
@@ -166,10 +192,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
// Generate session hash (from header for OpenAI)
sessionHash := h.gatewayService.GenerateSessionHash(c)
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
const maxAccountSwitches = 3
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0

View File

@@ -544,6 +544,11 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
body := w.buf.Bytes()
parsed := parseOpsErrorResponse(body)
// Skip logging if the error should be filtered based on settings
if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
return
}
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
clientRequestID, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
@@ -832,28 +837,30 @@ func normalizeOpsErrorType(errType string, code string) string {
func classifyOpsPhase(errType, message, code string) string {
msg := strings.ToLower(message)
// Standardized phases: request|auth|routing|upstream|network|internal
// Map billing/concurrency/response => request; scheduling => routing.
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
return "billing"
return "request"
}
switch errType {
case "authentication_error":
return "auth"
case "billing_error", "subscription_error":
return "billing"
return "request"
case "rate_limit_error":
if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") || strings.Contains(msg, "queue") {
return "concurrency"
return "request"
}
return "upstream"
case "invalid_request_error":
return "response"
return "request"
case "upstream_error", "overloaded_error":
return "upstream"
case "api_error":
if strings.Contains(msg, "no available accounts") {
return "scheduling"
return "routing"
}
return "internal"
default:
@@ -914,34 +921,38 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
}
func classifyOpsErrorOwner(phase string, message string) string {
// Standardized owners: client|provider|platform
switch phase {
case "upstream", "network":
return "provider"
case "billing", "concurrency", "auth", "response":
case "request", "auth":
return "client"
case "routing", "internal":
return "platform"
default:
if strings.Contains(strings.ToLower(message), "upstream") {
return "provider"
}
return "sub2api"
return "platform"
}
}
func classifyOpsErrorSource(phase string, message string) string {
// Standardized sources: client_request|upstream_http|gateway
switch phase {
case "upstream":
return "upstream_http"
case "network":
return "upstream_network"
case "billing":
return "billing"
case "concurrency":
return "concurrency"
return "gateway"
case "request", "auth":
return "client_request"
case "routing", "internal":
return "gateway"
default:
if strings.Contains(strings.ToLower(message), "upstream") {
return "upstream_http"
}
return "internal"
return "gateway"
}
}
@@ -963,3 +974,42 @@ func truncateString(s string, max int) string {
func strconvItoa(v int) string {
return strconv.Itoa(v)
}
// shouldSkipOpsErrorLog determines if an error should be skipped from logging based on settings.
// Returns true for errors that should be filtered according to OpsAdvancedSettings.
func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message, body, requestPath string) bool {
if ops == nil {
return false
}
// Get advanced settings to check filter configuration
settings, err := ops.GetOpsAdvancedSettings(ctx)
if err != nil || settings == nil {
// If we can't get settings, don't skip (fail open)
return false
}
msgLower := strings.ToLower(message)
bodyLower := strings.ToLower(body)
// Check if count_tokens errors should be ignored
if settings.IgnoreCountTokensErrors && strings.Contains(requestPath, "/count_tokens") {
return true
}
// Check if context canceled errors should be ignored (client disconnects)
if settings.IgnoreContextCanceled {
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
return true
}
}
// Check if "no available accounts" errors should be ignored
if settings.IgnoreNoAvailableAccounts {
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
return true
}
}
return false
}

View File

@@ -32,18 +32,21 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
}
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
})
}

View File

@@ -0,0 +1,181 @@
package handler
import (
"github.com/gin-gonic/gin"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TotpHandler handles TOTP-related requests
type TotpHandler struct {
totpService *service.TotpService
}
// NewTotpHandler creates a new TotpHandler
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
return &TotpHandler{
totpService: totpService,
}
}
// TotpStatusResponse represents the TOTP status response
type TotpStatusResponse struct {
Enabled bool `json:"enabled"`
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
FeatureEnabled bool `json:"feature_enabled"`
}
// GetStatus returns the TOTP status for the current user
// GET /api/v1/user/totp/status
func (h *TotpHandler) GetStatus(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := TotpStatusResponse{
Enabled: status.Enabled,
FeatureEnabled: status.FeatureEnabled,
}
if status.EnabledAt != nil {
ts := status.EnabledAt.Unix()
resp.EnabledAt = &ts
}
response.Success(c, resp)
}
// TotpSetupRequest represents the request to initiate TOTP setup
type TotpSetupRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// TotpSetupResponse represents the TOTP setup response
type TotpSetupResponse struct {
Secret string `json:"secret"`
QRCodeURL string `json:"qr_code_url"`
SetupToken string `json:"setup_token"`
Countdown int `json:"countdown"`
}
// InitiateSetup starts the TOTP setup process
// POST /api/v1/user/totp/setup
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpSetupRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body (optional params)
req = TotpSetupRequest{}
}
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, TotpSetupResponse{
Secret: result.Secret,
QRCodeURL: result.QRCodeURL,
SetupToken: result.SetupToken,
Countdown: result.Countdown,
})
}
// TotpEnableRequest represents the request to enable TOTP
type TotpEnableRequest struct {
TotpCode string `json:"totp_code" binding:"required,len=6"`
SetupToken string `json:"setup_token" binding:"required"`
}
// Enable completes the TOTP setup
// POST /api/v1/user/totp/enable
func (h *TotpHandler) Enable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpEnableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// TotpDisableRequest represents the request to disable TOTP
type TotpDisableRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// Disable disables TOTP for the current user
// POST /api/v1/user/totp/disable
func (h *TotpHandler) Disable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpDisableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// GetVerificationMethod returns the verification method for TOTP operations
// GET /api/v1/user/totp/verification-method
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
method := h.totpService.GetVerificationMethod(c.Request.Context())
response.Success(c, method)
}
// SendVerifyCode sends an email verification code for TOTP operations
// POST /api/v1/user/totp/send-code
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}

View File

@@ -47,9 +47,6 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return
}
// 清空notes字段普通用户不应看到备注
userData.Notes = ""
response.Success(c, dto.UserFromService(userData))
}
@@ -105,8 +102,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return
}
// 清空notes字段普通用户不应看到备注
updatedUser.Notes = ""
response.Success(c, dto.UserFromService(updatedUser))
}

View File

@@ -70,6 +70,7 @@ func ProvideHandlers(
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
settingHandler *SettingHandler,
totpHandler *TotpHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
@@ -82,6 +83,7 @@ func ProvideHandlers(
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler,
Totp: totpHandler,
}
}
@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
NewTotpHandler,
ProvideSettingHandler,
// Admin handlers

View File

@@ -2,7 +2,10 @@ package middleware
import (
"context"
"fmt"
"log"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
@@ -25,15 +28,34 @@ type RateLimitOptions struct {
var rateLimitScript = redis.NewScript(`
local current = redis.call('INCR', KEYS[1])
local ttl = redis.call('PTTL', KEYS[1])
if current == 1 or ttl == -1 then
local repaired = 0
if current == 1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
elseif ttl == -1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
repaired = 1
end
return current
return {current, repaired}
`)
// rateLimitRun 允许测试覆写脚本执行逻辑
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
return rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Int64()
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice()
if err != nil {
return 0, false, err
}
if len(values) < 2 {
return 0, false, fmt.Errorf("rate limit script returned %d values", len(values))
}
count, err := parseInt64(values[0])
if err != nil {
return 0, false, err
}
repaired, err := parseInt64(values[1])
if err != nil {
return 0, false, err
}
return count, repaired == 1, nil
}
// RateLimiter Redis 速率限制器
@@ -74,8 +96,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
windowMillis := windowTTLMillis(window)
// 使用 Lua 脚本原子操作增加计数并设置过期
count, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
if err != nil {
log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err)
if failureMode == RateLimitFailClose {
abortRateLimit(c)
return
@@ -84,6 +107,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
c.Next()
return
}
if repaired {
log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis)
}
// 超过限制
if count > int64(limit) {
@@ -109,3 +135,27 @@ func abortRateLimit(c *gin.Context) {
"message": "Too many requests, please try again later",
})
}
func failureModeLabel(mode RateLimitFailureMode) string {
if mode == RateLimitFailClose {
return "fail-close"
}
return "fail-open"
}
func parseInt64(value any) (int64, error) {
switch v := value.(type) {
case int64:
return v, nil
case int:
return int64(v), nil
case string:
parsed, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return 0, err
}
return parsed, nil
default:
return 0, fmt.Errorf("unexpected value type %T", value)
}
}

View File

@@ -7,6 +7,9 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"testing"
"time"
@@ -88,6 +91,7 @@ func performRequest(router *gin.Engine) *httptest.ResponseRecorder {
func startRedis(t *testing.T, ctx context.Context) *redis.Client {
t.Helper()
ensureDockerAvailable(t)
redisContainer, err := tcredis.Run(ctx, redisImageTag)
require.NoError(t, err)
@@ -112,3 +116,43 @@ func startRedis(t *testing.T, ctx context.Context) *redis.Client {
return rdb
}
func ensureDockerAvailable(t *testing.T) {
t.Helper()
if dockerAvailable() {
return
}
t.Skip("Docker 未启用,跳过依赖 testcontainers 的集成测试")
}
func dockerAvailable() bool {
if os.Getenv("DOCKER_HOST") != "" {
return true
}
socketCandidates := []string{
"/var/run/docker.sock",
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
filepath.Join(userHomeDir(), ".docker", "run", "docker.sock"),
filepath.Join(userHomeDir(), ".docker", "desktop", "docker.sock"),
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
}
for _, socket := range socketCandidates {
if socket == "" {
continue
}
if _, err := os.Stat(socket); err == nil {
return true
}
}
return false
}
func userHomeDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return home
}

View File

@@ -66,13 +66,13 @@ func TestRateLimiterSuccessAndLimit(t *testing.T) {
originalRun := rateLimitRun
counts := []int64{1, 2}
callIndex := 0
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
if callIndex >= len(counts) {
return counts[len(counts)-1], nil
return counts[len(counts)-1], false, nil
}
value := counts[callIndex]
callIndex++
return value, nil
return value, false, nil
}
t.Cleanup(func() {
rateLimitRun = originalRun

View File

@@ -16,15 +16,6 @@ import (
"time"
)
// resolveHost 从 URL 解析 host
func resolveHost(urlStr string) string {
parsed, err := url.Parse(urlStr)
if err != nil {
return ""
}
return parsed.Host
}
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求v1internal 端点)
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
// 构建 URL流式请求添加 ?alt=sse 参数
@@ -39,23 +30,11 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
return nil, err
}
// 基础 Headers
// 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", UserAgent)
// Accept Header 根据请求类型设置
if isStream {
req.Header.Set("Accept", "text/event-stream")
} else {
req.Header.Set("Accept", "application/json")
}
// 显式设置 Host Header
if host := resolveHost(apiURL); host != "" {
req.Host = host
}
return req, nil
}
@@ -195,12 +174,15 @@ func isConnectionError(err error) bool {
}
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
// 仅连接错误和 HTTP 429 触发 URL 降级
// 与 Antigravity-Manager 保持一致连接错误、429、408、404、5xx 触发 URL 降级
func shouldFallbackToNextURL(err error, statusCode int) bool {
if isConnectionError(err) {
return true
}
return statusCode == http.StatusTooManyRequests
return statusCode == http.StatusTooManyRequests ||
statusCode == http.StatusRequestTimeout ||
statusCode == http.StatusNotFound ||
statusCode >= 500
}
// ExchangeCode 用 authorization code 交换 token
@@ -321,11 +303,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
}
// 获取可用的 URL 列表
availableURLs := DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
}
// 固定顺序prod -> daily
availableURLs := BaseURLs
var lastErr error
for urlIdx, baseURL := range availableURLs {
@@ -343,7 +322,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
if err != nil {
lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
@@ -358,7 +336,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
// 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
@@ -376,6 +353,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp)
// 标记成功的 URL下次优先使用
DefaultURLAvailability.MarkSuccess(baseURL)
return &loadResp, rawResp, nil
}
@@ -412,11 +391,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
}
// 获取可用的 URL 列表
availableURLs := DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
}
// 固定顺序prod -> daily
availableURLs := BaseURLs
var lastErr error
for urlIdx, baseURL := range availableURLs {
@@ -434,7 +410,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
if err != nil {
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
@@ -449,7 +424,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
// 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
@@ -467,6 +441,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp)
// 标记成功的 URL下次优先使用
DefaultURLAvailability.MarkSuccess(baseURL)
return &modelsResp, rawResp, nil
}

View File

@@ -143,9 +143,10 @@ type GeminiResponse struct {
// GeminiCandidate Gemini 候选响应
type GeminiCandidate struct {
Content *GeminiContent `json:"content,omitempty"`
FinishReason string `json:"finishReason,omitempty"`
Index int `json:"index,omitempty"`
Content *GeminiContent `json:"content,omitempty"`
FinishReason string `json:"finishReason,omitempty"`
Index int `json:"index,omitempty"`
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
}
// GeminiUsageMetadata Gemini 用量元数据
@@ -156,6 +157,23 @@ type GeminiUsageMetadata struct {
TotalTokenCount int `json:"totalTokenCount,omitempty"`
}
// GeminiGroundingMetadata Gemini grounding 元数据Web Search
type GeminiGroundingMetadata struct {
WebSearchQueries []string `json:"webSearchQueries,omitempty"`
GroundingChunks []GeminiGroundingChunk `json:"groundingChunks,omitempty"`
}
// GeminiGroundingChunk Gemini grounding chunk
type GeminiGroundingChunk struct {
Web *GeminiGroundingWeb `json:"web,omitempty"`
}
// GeminiGroundingWeb Gemini grounding web 信息
type GeminiGroundingWeb struct {
Title string `json:"title,omitempty"`
URI string `json:"uri,omitempty"`
}
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
var DefaultSafetySettings = []GeminiSafetySetting{
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},

View File

@@ -32,8 +32,8 @@ const (
"https://www.googleapis.com/auth/cclog " +
"https://www.googleapis.com/auth/experimentsandconfigs"
// User-Agent模拟官方客户端
UserAgent = "antigravity/1.104.0 darwin/arm64"
// User-Agent与 Antigravity-Manager 保持一致
UserAgent = "antigravity/1.11.9 windows/amd64"
// Session 过期时间
SessionTTL = 30 * time.Minute
@@ -42,22 +42,21 @@ const (
URLAvailabilityTTL = 5 * time.Minute
)
// BaseURLs 定义 Antigravity API 端点,按优先级排序
// fallback 顺序: sandbox → daily → prod
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
var BaseURLs = []string{
"https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox
"https://daily-cloudcode-pa.googleapis.com", // daily
"https://cloudcode-pa.googleapis.com", // prod
"https://cloudcode-pa.googleapis.com", // prod (优先)
"https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用)
}
// BaseURL 默认 URL保持向后兼容
var BaseURL = BaseURLs[0]
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复)
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级
type URLAvailability struct {
mu sync.RWMutex
unavailable map[string]time.Time // URL -> 恢复时间
ttl time.Duration
lastSuccess string // 最近成功请求的 URL优先使用
}
// DefaultURLAvailability 全局 URL 可用性管理器
@@ -78,6 +77,15 @@ func (u *URLAvailability) MarkUnavailable(url string) {
u.unavailable[url] = time.Now().Add(u.ttl)
}
// MarkSuccess 标记 URL 请求成功,将其设为优先使用
func (u *URLAvailability) MarkSuccess(url string) {
u.mu.Lock()
defer u.mu.Unlock()
u.lastSuccess = url
// 成功后清除该 URL 的不可用标记
delete(u.unavailable, url)
}
// IsAvailable 检查 URL 是否可用
func (u *URLAvailability) IsAvailable(url string) bool {
u.mu.RLock()
@@ -89,14 +97,29 @@ func (u *URLAvailability) IsAvailable(url string) bool {
return time.Now().After(expiry)
}
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序)
// GetAvailableURLs 返回可用的 URL 列表
// 最近成功的 URL 优先,其他按默认顺序
func (u *URLAvailability) GetAvailableURLs() []string {
u.mu.RLock()
defer u.mu.RUnlock()
now := time.Now()
result := make([]string, 0, len(BaseURLs))
// 如果有最近成功的 URL 且可用,放在最前面
if u.lastSuccess != "" {
expiry, exists := u.unavailable[u.lastSuccess]
if !exists || now.After(expiry) {
result = append(result, u.lastSuccess)
}
}
// 添加其他可用的 URL按默认顺序
for _, url := range BaseURLs {
// 跳过已添加的 lastSuccess
if url == u.lastSuccess {
continue
}
expiry, exists := u.unavailable[url]
if !exists || now.After(expiry) {
result = append(result, url)
@@ -240,24 +263,3 @@ func BuildAuthorizationURL(state, codeChallenge string) string {
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// GenerateMockProjectID 生成随机 project_id当 API 不返回时使用)
// 格式:{形容词}-{名词}-{5位随机字符}
func GenerateMockProjectID() string {
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
randBytes, _ := GenerateRandomBytes(7)
adj := adjectives[int(randBytes[0])%len(adjectives)]
noun := nouns[int(randBytes[1])%len(nouns)]
// 生成 5 位随机字符a-z0-9
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
suffix := make([]byte, 5)
for i := 0; i < 5; i++ {
suffix[i] = charset[int(randBytes[i+2])%len(charset)]
}
return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
}

View File

@@ -7,13 +7,11 @@ import (
"fmt"
"log"
"math/rand"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -54,6 +52,9 @@ func DefaultTransformOptions() TransformOptions {
}
}
// webSearchFallbackModel web_search 请求使用的降级模型
const webSearchFallbackModel = "gemini-2.5-flash"
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
@@ -64,12 +65,23 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
// 检测是否有 web_search 工具
hasWebSearchTool := hasWebSearchTool(claudeReq.Tools)
requestType := "agent"
targetModel := mappedModel
if hasWebSearchTool {
requestType = "web_search"
if targetModel != webSearchFallbackModel {
targetModel = webSearchFallbackModel
}
}
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
allowDummyThought := strings.HasPrefix(targetModel, "gemini-")
// 1. 构建 contents
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
@@ -78,7 +90,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
}
// 2. 构建 systemInstruction
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts)
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts, claudeReq.Tools)
// 3. 构建 generationConfig
reqForConfig := claudeReq
@@ -89,6 +101,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
reqCopy.Thinking = nil
reqForConfig = &reqCopy
}
if targetModel != "" && targetModel != reqForConfig.Model {
reqCopy := *reqForConfig
reqCopy.Model = targetModel
reqForConfig = &reqCopy
}
generationConfig := buildGenerationConfig(reqForConfig)
// 4. 构建 tools
@@ -127,8 +144,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
Project: projectID,
RequestID: "agent-" + uuid.New().String(),
UserAgent: "antigravity", // 固定值,与官方客户端一致
RequestType: "agent",
Model: mappedModel,
RequestType: requestType,
Model: targetModel,
Request: innerRequest,
}
@@ -154,8 +171,40 @@ func GetDefaultIdentityPatch() string {
return antigravityIdentity
}
// buildSystemInstruction 构建 systemInstruction
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
// mcpXMLProtocol MCP XML 工具调用协议(与 Antigravity-Manager 保持一致)
const mcpXMLProtocol = `
==== MCP XML 工具调用协议 (Workaround) ====
当你需要调用名称以 ` + "`mcp__`" + ` 开头的 MCP 工具时:
1) 优先尝试 XML 格式调用:输出 ` + "`<mcp__tool_name>{\"arg\":\"value\"}</mcp__tool_name>`" + `
2) 必须直接输出 XML 块,无需 markdown 包装,内容为 JSON 格式的入参。
3) 这种方式具有更高的连通性和容错性,适用于大型结果返回场景。
===========================================`
// hasMCPTools 检测是否有 mcp__ 前缀的工具
func hasMCPTools(tools []ClaudeTool) bool {
for _, tool := range tools {
if strings.HasPrefix(tool.Name, "mcp__") {
return true
}
}
return false
}
// filterOpenCodePrompt 过滤 OpenCode 默认提示词,只保留用户自定义指令
func filterOpenCodePrompt(text string) string {
if !strings.Contains(text, "You are an interactive CLI tool") {
return text
}
// 提取 "Instructions from:" 及之后的部分
if idx := strings.Index(text, "Instructions from:"); idx >= 0 {
return text[idx:]
}
// 如果没有自定义指令,返回空
return ""
}
// buildSystemInstruction 构建 systemInstruction与 Antigravity-Manager 保持一致)
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
var parts []GeminiPart
// 先解析用户的 system prompt检测是否已包含 Antigravity identity
@@ -167,10 +216,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
var sysStr string
if err := json.Unmarshal(system, &sysStr); err == nil {
if strings.TrimSpace(sysStr) != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr})
if strings.Contains(sysStr, "You are Antigravity") {
userHasAntigravityIdentity = true
}
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(sysStr)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
}
} else {
// 尝试解析为数组
@@ -178,10 +231,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if err := json.Unmarshal(system, &sysBlocks); err == nil {
for _, block := range sysBlocks {
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text})
if strings.Contains(block.Text, "You are Antigravity") {
userHasAntigravityIdentity = true
}
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(block.Text)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
}
}
}
@@ -200,6 +257,16 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
// 添加用户的 system prompt
parts = append(parts, userSystemParts...)
// 检测是否有 MCP 工具,如有则注入 XML 调用协议
if hasMCPTools(tools) {
parts = append(parts, GeminiPart{Text: mcpXMLProtocol})
}
// 如果用户没有提供 Antigravity 身份,添加结束标记
if !userHasAntigravityIdentity {
parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
}
if len(parts) == 0 {
return nil
}
@@ -300,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
Text: block.Thinking,
Thought: true,
}
// 保留原有 signatureClaude 模型需要有效的 signature
if block.Signature != "" {
// signature 处理:
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失)
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature在缺失时降级为普通文本并在上层禁用 thinking mode。
@@ -340,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
},
}
// tool_use 的 signature 处理:
// - Gemini 模型:使用 dummy signature跳过 thought_signature 校验
// - Claude 模型:透传上游返回的真实 signatureVertex/Google 需要完整签名链路)
if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
@@ -429,6 +498,11 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
StopSequences: DefaultStopSequences,
}
// 如果请求中指定了 MaxTokens使用请求值
if req.MaxTokens > 0 {
config.MaxOutputTokens = req.MaxTokens
}
// Thinking 配置
if req.Thinking != nil && req.Thinking.Type == "enabled" {
config.ThinkingConfig = &GeminiThinkingConfig{
@@ -458,37 +532,43 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
return config
}
func hasWebSearchTool(tools []ClaudeTool) bool {
for _, tool := range tools {
if isWebSearchTool(tool) {
return true
}
}
return false
}
func isWebSearchTool(tool ClaudeTool) bool {
if strings.HasPrefix(tool.Type, "web_search") || tool.Type == "google_search" {
return true
}
name := strings.TrimSpace(tool.Name)
switch name {
case "web_search", "google_search", "web_search_20250305":
return true
default:
return false
}
}
// buildTools 构建 tools
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
if len(tools) == 0 {
return nil
}
// 检查是否有 web_search 工具
hasWebSearch := false
for _, tool := range tools {
if tool.Name == "web_search" {
hasWebSearch = true
break
}
}
if hasWebSearch {
// Web Search 工具映射
return []GeminiToolDeclaration{{
GoogleSearch: &GeminiGoogleSearch{
EnhancedContent: &GeminiEnhancedContent{
ImageSearch: &GeminiImageSearch{
MaxResultCount: 5,
},
},
},
}}
}
hasWebSearch := hasWebSearchTool(tools)
// 普通工具
var funcDecls []GeminiFunctionDecl
for _, tool := range tools {
if isWebSearchTool(tool) {
continue
}
// 跳过无效工具名称
if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name")
@@ -514,11 +594,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
// 清理 JSON Schema
params := cleanJSONSchema(inputSchema)
// 1. 深度清理 [undefined] 值
DeepCleanUndefined(inputSchema)
// 2. 转换为符合 Gemini v1internal 的 schema
params := CleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值
if params == nil {
params = map[string]any{
"type": "OBJECT",
"type": "object", // lowercase type
"properties": map[string]any{},
}
}
@@ -531,243 +614,23 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
if len(funcDecls) == 0 {
return nil
if !hasWebSearch {
return nil
}
// Web Search 工具映射
return []GeminiToolDeclaration{{
GoogleSearch: &GeminiGoogleSearch{
EnhancedContent: &GeminiEnhancedContent{
ImageSearch: &GeminiImageSearch{
MaxResultCount: 5,
},
},
},
}}
}
return []GeminiToolDeclaration{{
FunctionDeclarations: funcDecls,
}}
}
// cleanJSONSchema 清理 JSON Schema移除 Antigravity/Gemini 不支持的字段
// 参考 proxycast 的实现,确保 schema 符合 JSON Schema draft 2020-12
func cleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
cleaned := cleanSchemaValue(schema, "$")
result, ok := cleaned.(map[string]any)
if !ok {
return nil
}
// 确保有 type 字段(默认 OBJECT
if _, hasType := result["type"]; !hasType {
result["type"] = "OBJECT"
}
// 确保有 properties 字段(默认空对象)
if _, hasProps := result["properties"]; !hasProps {
result["properties"] = make(map[string]any)
}
// 验证 required 中的字段都存在于 properties 中
if required, ok := result["required"].([]any); ok {
if props, ok := result["properties"].(map[string]any); ok {
validRequired := make([]any, 0, len(required))
for _, r := range required {
if reqName, ok := r.(string); ok {
if _, exists := props[reqName]; exists {
validRequired = append(validRequired, r)
}
}
}
if len(validRequired) > 0 {
result["required"] = validRequired
} else {
delete(result, "required")
}
}
}
return result
}
var schemaValidationKeys = map[string]bool{
"minLength": true,
"maxLength": true,
"pattern": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
"uniqueItems": true,
"minItems": true,
"maxItems": true,
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
}
var warnedSchemaKeys sync.Map
func schemaCleaningWarningsEnabled() bool {
// 可通过环境变量强制开关方便排查SUB2API_SCHEMA_CLEAN_WARN=true/false
if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
}
}
// 默认:非 release 模式下输出debug/test
return gin.Mode() != gin.ReleaseMode
}
func warnSchemaKeyRemovedOnce(key, path string) {
if !schemaCleaningWarningsEnabled() {
return
}
if !schemaValidationKeys[key] {
return
}
if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
return
}
log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{
// 元 schema 字段
"$schema": true,
"$id": true,
"$ref": true,
// 字符串验证Gemini 不支持)
"minLength": true,
"maxLength": true,
"pattern": true,
// 数字验证Claude API 通过 Vertex AI 不支持这些字段)
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
// 数组验证Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems": true,
"minItems": true,
"maxItems": true,
// 组合 schemaGemini 不支持)
"oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// Claude 特有字段
"strict": true,
}
// cleanSchemaValue 递归清理 schema 值
func cleanSchemaValue(value any, path string) any {
switch v := value.(type) {
case map[string]any:
result := make(map[string]any)
for k, val := range v {
// 跳过不支持的字段
if excludedSchemaKeys[k] {
warnSchemaKeyRemovedOnce(k, path)
continue
}
// 特殊处理 type 字段
if k == "type" {
result[k] = cleanTypeValue(val)
continue
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalPropertiesClaude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
} else {
// 如果是 schema 对象,转换为 false更安全的默认值
result[k] = false
}
continue
}
// 递归清理所有值
result[k] = cleanSchemaValue(val, path+"."+k)
}
return result
case []any:
// 递归处理数组中的每个元素
cleaned := make([]any, 0, len(v))
for i, item := range v {
cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
}
return cleaned
default:
return value
}
}
// cleanTypeValue 处理 type 字段,转换为大写
func cleanTypeValue(value any) any {
switch v := value.(type) {
case string:
return strings.ToUpper(v)
case []any:
// 联合类型 ["string", "null"] -> 取第一个非 null 类型
for _, t := range v {
if ts, ok := t.(string); ok && ts != "null" {
return strings.ToUpper(ts)
}
}
// 如果只有 null返回 STRING
return "STRING"
default:
return value
}
}

View File

@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]`
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil {
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != "sig_tool_abc" {
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
}
})
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
contentNoSig := `[
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
]`
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
}

View File

@@ -3,6 +3,8 @@ package antigravity
import (
"encoding/json"
"fmt"
"log"
"strings"
)
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
@@ -18,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
} else if len(v1Resp.Response.Candidates) == 0 {
// 第一次解析成功但 candidates 为空,说明是直接的 GeminiResponse 格式
var directResp GeminiResponse
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
return nil, nil, fmt.Errorf("parse gemini response as direct: %w", err2)
}
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
}
// 使用处理器转换
@@ -63,6 +74,12 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID,
p.processPart(&part)
}
if len(geminiResp.Candidates) > 0 {
if grounding := geminiResp.Candidates[0].GroundingMetadata; grounding != nil {
p.processGrounding(grounding)
}
}
// 刷新剩余内容
p.flushThinking()
p.flushText()
@@ -166,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
p.trailingSignature = ""
}
p.textBuilder += part.Text
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
// 非空 text 带签名 - 特殊处理:先输出 text再输出空 thinking 块
if signature != "" {
p.flushText()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "text",
Text: part.Text,
})
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: signature,
})
} else {
// 普通 text (无签名) - 累积到 builder
p.textBuilder += part.Text
}
}
}
@@ -190,6 +211,18 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
}
}
func (p *NonStreamingProcessor) processGrounding(grounding *GeminiGroundingMetadata) {
groundingText := buildGroundingText(grounding)
if groundingText == "" {
return
}
p.flushThinking()
p.flushText()
p.textBuilder += groundingText
p.flushText()
}
// flushText 刷新 text builder
func (p *NonStreamingProcessor) flushText() {
if p.textBuilder == "" {
@@ -223,6 +256,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
var finishReason string
if len(geminiResp.Candidates) > 0 {
finishReason = geminiResp.Candidates[0].FinishReason
if finishReason == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in response for model %s", originalModel)
if geminiResp.Candidates[0].Content != nil {
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
stopReason := "end_turn"
@@ -262,6 +303,44 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
}
}
func buildGroundingText(grounding *GeminiGroundingMetadata) string {
if grounding == nil {
return ""
}
var builder strings.Builder
if len(grounding.WebSearchQueries) > 0 {
_, _ = builder.WriteString("\n\n---\nWeb search queries: ")
_, _ = builder.WriteString(strings.Join(grounding.WebSearchQueries, ", "))
}
if len(grounding.GroundingChunks) > 0 {
var links []string
for i, chunk := range grounding.GroundingChunks {
if chunk.Web == nil {
continue
}
title := strings.TrimSpace(chunk.Web.Title)
if title == "" {
title = "Source"
}
uri := strings.TrimSpace(chunk.Web.URI)
if uri == "" {
uri = "#"
}
links = append(links, fmt.Sprintf("[%d] [%s](%s)", i+1, title, uri))
}
if len(links) > 0 {
_, _ = builder.WriteString("\n\nSources:\n")
_, _ = builder.WriteString(strings.Join(links, "\n"))
}
}
return builder.String()
}
// generateRandomID 生成随机 ID
func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

View File

@@ -0,0 +1,519 @@
package antigravity
import (
"fmt"
"strings"
)
// CleanJSONSchema 清理 JSON Schema移除 Antigravity/Gemini 不支持的字段
// 参考 Antigravity-Manager/src-tauri/src/proxy/common/json_schema.rs 实现
// 确保 schema 符合 JSON Schema draft 2020-12 且适配 Gemini v1internal
func CleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
// 0. 预处理:展开 $ref (Schema Flattening)
// (Go map 是引用的,直接修改 schema)
flattenRefs(schema, extractDefs(schema))
// 递归清理
cleaned := cleanJSONSchemaRecursive(schema)
result, ok := cleaned.(map[string]any)
if !ok {
return nil
}
return result
}
// extractDefs 提取并移除定义的 helper
func extractDefs(schema map[string]any) map[string]any {
defs := make(map[string]any)
if d, ok := schema["$defs"].(map[string]any); ok {
for k, v := range d {
defs[k] = v
}
delete(schema, "$defs")
}
if d, ok := schema["definitions"].(map[string]any); ok {
for k, v := range d {
defs[k] = v
}
delete(schema, "definitions")
}
return defs
}
// flattenRefs 递归展开 $ref
func flattenRefs(schema map[string]any, defs map[string]any) {
if len(defs) == 0 {
return // 无需展开
}
// 检查并替换 $ref
if ref, ok := schema["$ref"].(string); ok {
delete(schema, "$ref")
// 解析引用名 (例如 #/$defs/MyType -> MyType)
parts := strings.Split(ref, "/")
refName := parts[len(parts)-1]
if defSchema, exists := defs[refName]; exists {
if defMap, ok := defSchema.(map[string]any); ok {
// 合并定义内容 (不覆盖现有 key)
for k, v := range defMap {
if _, has := schema[k]; !has {
schema[k] = deepCopy(v) // 需深拷贝避免共享引用
}
}
// 递归处理刚刚合并进来的内容
flattenRefs(schema, defs)
}
}
}
// 遍历子节点
for _, v := range schema {
if subMap, ok := v.(map[string]any); ok {
flattenRefs(subMap, defs)
} else if subArr, ok := v.([]any); ok {
for _, item := range subArr {
if itemMap, ok := item.(map[string]any); ok {
flattenRefs(itemMap, defs)
}
}
}
}
}
// deepCopy 深拷贝 (简单实现,仅针对 JSON 类型)
func deepCopy(src any) any {
if src == nil {
return nil
}
switch v := src.(type) {
case map[string]any:
dst := make(map[string]any)
for k, val := range v {
dst[k] = deepCopy(val)
}
return dst
case []any:
dst := make([]any, len(v))
for i, val := range v {
dst[i] = deepCopy(val)
}
return dst
default:
return src
}
}
// cleanJSONSchemaRecursive 递归核心清理逻辑
// 返回处理后的值 (通常是 input map但可能修改内部结构)
func cleanJSONSchemaRecursive(value any) any {
schemaMap, ok := value.(map[string]any)
if !ok {
return value
}
// 0. [NEW] 合并 allOf
mergeAllOf(schemaMap)
// 1. [CRITICAL] 深度递归处理子项
if props, ok := schemaMap["properties"].(map[string]any); ok {
for _, v := range props {
cleanJSONSchemaRecursive(v)
}
// Go 中不需要像 Rust 那样显式处理 nullable_keys remove required
// 因为我们在子项处理中会正确设置 type 和 description
} else if items, ok := schemaMap["items"]; ok {
// [FIX] Gemini 期望 "items" 是单个 Schema 对象(列表验证),而不是数组(元组验证)。
if itemsArr, ok := items.([]any); ok {
// 策略:将元组 [A, B] 视为 A、B 中的最佳匹配项。
best := extractBestSchemaFromUnion(itemsArr)
if best == nil {
// 回退到通用字符串
best = map[string]any{"type": "string"}
}
// 用处理后的对象替换原有数组
cleanedBest := cleanJSONSchemaRecursive(best)
schemaMap["items"] = cleanedBest
} else {
cleanJSONSchemaRecursive(items)
}
} else {
// 遍历所有值递归
for _, v := range schemaMap {
if _, isMap := v.(map[string]any); isMap {
cleanJSONSchemaRecursive(v)
} else if arr, isArr := v.([]any); isArr {
for _, item := range arr {
cleanJSONSchemaRecursive(item)
}
}
}
}
// 2. [FIX] 处理 anyOf/oneOf 联合类型: 合并属性而非直接删除
var unionArray []any
typeStr, _ := schemaMap["type"].(string)
if typeStr == "" || typeStr == "object" {
if anyOf, ok := schemaMap["anyOf"].([]any); ok {
unionArray = anyOf
} else if oneOf, ok := schemaMap["oneOf"].([]any); ok {
unionArray = oneOf
}
}
if len(unionArray) > 0 {
if bestBranch := extractBestSchemaFromUnion(unionArray); bestBranch != nil {
if bestMap, ok := bestBranch.(map[string]any); ok {
// 合并分支内容
for k, v := range bestMap {
if k == "properties" {
targetProps, _ := schemaMap["properties"].(map[string]any)
if targetProps == nil {
targetProps = make(map[string]any)
schemaMap["properties"] = targetProps
}
if sourceProps, ok := v.(map[string]any); ok {
for pk, pv := range sourceProps {
if _, exists := targetProps[pk]; !exists {
targetProps[pk] = deepCopy(pv)
}
}
}
} else if k == "required" {
targetReq, _ := schemaMap["required"].([]any)
if sourceReq, ok := v.([]any); ok {
for _, rv := range sourceReq {
// 简单的去重添加
exists := false
for _, tr := range targetReq {
if tr == rv {
exists = true
break
}
}
if !exists {
targetReq = append(targetReq, rv)
}
}
schemaMap["required"] = targetReq
}
} else if _, exists := schemaMap[k]; !exists {
schemaMap[k] = deepCopy(v)
}
}
}
}
}
// 3. [SAFETY] 检查当前对象是否为 JSON Schema 节点
looksLikeSchema := hasKey(schemaMap, "type") ||
hasKey(schemaMap, "properties") ||
hasKey(schemaMap, "items") ||
hasKey(schemaMap, "enum") ||
hasKey(schemaMap, "anyOf") ||
hasKey(schemaMap, "oneOf") ||
hasKey(schemaMap, "allOf")
if looksLikeSchema {
// 4. [ROBUST] 约束迁移
migrateConstraints(schemaMap)
// 5. [CRITICAL] 白名单过滤
allowedFields := map[string]bool{
"type": true,
"description": true,
"properties": true,
"required": true,
"items": true,
"enum": true,
"title": true,
}
for k := range schemaMap {
if !allowedFields[k] {
delete(schemaMap, k)
}
}
// 6. [SAFETY] 处理空 Object
if t, _ := schemaMap["type"].(string); t == "object" {
hasProps := false
if props, ok := schemaMap["properties"].(map[string]any); ok && len(props) > 0 {
hasProps = true
}
if !hasProps {
schemaMap["properties"] = map[string]any{
"reason": map[string]any{
"type": "string",
"description": "Reason for calling this tool",
},
}
schemaMap["required"] = []any{"reason"}
}
}
// 7. [SAFETY] Required 字段对齐
if props, ok := schemaMap["properties"].(map[string]any); ok {
if req, ok := schemaMap["required"].([]any); ok {
var validReq []any
for _, r := range req {
if rStr, ok := r.(string); ok {
if _, exists := props[rStr]; exists {
validReq = append(validReq, r)
}
}
}
if len(validReq) > 0 {
schemaMap["required"] = validReq
} else {
delete(schemaMap, "required")
}
}
}
// 8. 处理 type 字段 (Lowercase + Nullable 提取)
isEffectivelyNullable := false
if typeVal, exists := schemaMap["type"]; exists {
var selectedType string
switch v := typeVal.(type) {
case string:
lower := strings.ToLower(v)
if lower == "null" {
isEffectivelyNullable = true
selectedType = "string" // fallback
} else {
selectedType = lower
}
case []any:
// ["string", "null"]
for _, t := range v {
if ts, ok := t.(string); ok {
lower := strings.ToLower(ts)
if lower == "null" {
isEffectivelyNullable = true
} else if selectedType == "" {
selectedType = lower
}
}
}
if selectedType == "" {
selectedType = "string"
}
}
schemaMap["type"] = selectedType
} else {
// 默认 object 如果有 properties (虽然上面白名单过滤可能删了 type 如果它不在... 但 type 必在 allowlist)
// 如果没有 type但有 properties补一个
if hasKey(schemaMap, "properties") {
schemaMap["type"] = "object"
} else {
// 默认为 string ? or object? Gemini 通常需要明确 type
schemaMap["type"] = "object"
}
}
if isEffectivelyNullable {
desc, _ := schemaMap["description"].(string)
if !strings.Contains(desc, "nullable") {
if desc != "" {
desc += " "
}
desc += "(nullable)"
schemaMap["description"] = desc
}
}
// 9. Enum 值强制转字符串
if enumVals, ok := schemaMap["enum"].([]any); ok {
hasNonString := false
for i, val := range enumVals {
if _, isStr := val.(string); !isStr {
hasNonString = true
if val == nil {
enumVals[i] = "null"
} else {
enumVals[i] = fmt.Sprintf("%v", val)
}
}
}
// If we mandated string values, we must ensure type is string
if hasNonString {
schemaMap["type"] = "string"
}
}
}
return schemaMap
}
func hasKey(m map[string]any, k string) bool {
_, ok := m[k]
return ok
}
func migrateConstraints(m map[string]any) {
constraints := []struct {
key string
label string
}{
{"minLength", "minLen"},
{"maxLength", "maxLen"},
{"pattern", "pattern"},
{"minimum", "min"},
{"maximum", "max"},
{"multipleOf", "multipleOf"},
{"exclusiveMinimum", "exclMin"},
{"exclusiveMaximum", "exclMax"},
{"minItems", "minItems"},
{"maxItems", "maxItems"},
{"propertyNames", "propertyNames"},
{"format", "format"},
}
var hints []string
for _, c := range constraints {
if val, ok := m[c.key]; ok && val != nil {
hints = append(hints, fmt.Sprintf("%s: %v", c.label, val))
}
}
if len(hints) > 0 {
suffix := fmt.Sprintf(" [Constraint: %s]", strings.Join(hints, ", "))
desc, _ := m["description"].(string)
if !strings.Contains(desc, suffix) {
m["description"] = desc + suffix
}
}
}
// mergeAllOf 合并 allOf
func mergeAllOf(m map[string]any) {
allOf, ok := m["allOf"].([]any)
if !ok {
return
}
delete(m, "allOf")
mergedProps := make(map[string]any)
mergedReq := make(map[string]bool)
otherFields := make(map[string]any)
for _, sub := range allOf {
if subMap, ok := sub.(map[string]any); ok {
// Props
if props, ok := subMap["properties"].(map[string]any); ok {
for k, v := range props {
mergedProps[k] = v
}
}
// Required
if reqs, ok := subMap["required"].([]any); ok {
for _, r := range reqs {
if s, ok := r.(string); ok {
mergedReq[s] = true
}
}
}
// Others
for k, v := range subMap {
if k != "properties" && k != "required" && k != "allOf" {
if _, exists := otherFields[k]; !exists {
otherFields[k] = v
}
}
}
}
}
// Apply
for k, v := range otherFields {
if _, exists := m[k]; !exists {
m[k] = v
}
}
if len(mergedProps) > 0 {
existProps, _ := m["properties"].(map[string]any)
if existProps == nil {
existProps = make(map[string]any)
m["properties"] = existProps
}
for k, v := range mergedProps {
if _, exists := existProps[k]; !exists {
existProps[k] = v
}
}
}
if len(mergedReq) > 0 {
existReq, _ := m["required"].([]any)
var validReqs []any
for _, r := range existReq {
if s, ok := r.(string); ok {
validReqs = append(validReqs, s)
delete(mergedReq, s) // already exists
}
}
// append new
for r := range mergedReq {
validReqs = append(validReqs, r)
}
m["required"] = validReqs
}
}
// extractBestSchemaFromUnion 从 anyOf/oneOf 中选取最佳分支
func extractBestSchemaFromUnion(unionArray []any) any {
var bestOption any
bestScore := -1
for _, item := range unionArray {
score := scoreSchemaOption(item)
if score > bestScore {
bestScore = score
bestOption = item
}
}
return bestOption
}
func scoreSchemaOption(val any) int {
m, ok := val.(map[string]any)
if !ok {
return 0
}
typeStr, _ := m["type"].(string)
if hasKey(m, "properties") || typeStr == "object" {
return 3
}
if hasKey(m, "items") || typeStr == "array" {
return 2
}
if typeStr != "" && typeStr != "null" {
return 1
}
return 0
}
// DeepCleanUndefined 深度清理值为 "[undefined]" 的字段
func DeepCleanUndefined(value any) {
if value == nil {
return
}
switch v := value.(type) {
case map[string]any:
for k, val := range v {
if s, ok := val.(string); ok && s == "[undefined]" {
delete(v, k)
continue
}
DeepCleanUndefined(val)
}
case []any:
for _, val := range v {
DeepCleanUndefined(val)
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"log"
"strings"
)
@@ -27,6 +28,8 @@ type StreamingProcessor struct {
pendingSignature string
trailingSignature string
originalModel string
webSearchQueries []string
groundingChunks []GeminiGroundingChunk
// 累计 usage
inputTokens int
@@ -93,9 +96,21 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
}
}
if len(geminiResp.Candidates) > 0 {
p.captureGrounding(geminiResp.Candidates[0].GroundingMetadata)
}
// 检查是否结束
if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason
if finishReason == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in stream for model %s", p.originalModel)
if geminiResp.Candidates[0].Content != nil {
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
if finishReason != "" {
_, _ = result.Write(p.emitFinish(finishReason))
}
@@ -200,6 +215,20 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
return result.Bytes()
}
func (p *StreamingProcessor) captureGrounding(grounding *GeminiGroundingMetadata) {
if grounding == nil {
return
}
if len(grounding.WebSearchQueries) > 0 && len(p.webSearchQueries) == 0 {
p.webSearchQueries = append([]string(nil), grounding.WebSearchQueries...)
}
if len(grounding.GroundingChunks) > 0 && len(p.groundingChunks) == 0 {
p.groundingChunks = append([]GeminiGroundingChunk(nil), grounding.GroundingChunks...)
}
}
// processThinking 处理 thinking
func (p *StreamingProcessor) processThinking(text, signature string) []byte {
var result bytes.Buffer
@@ -417,6 +446,23 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
p.trailingSignature = ""
}
if len(p.webSearchQueries) > 0 || len(p.groundingChunks) > 0 {
groundingText := buildGroundingText(&GeminiGroundingMetadata{
WebSearchQueries: p.webSearchQueries,
GroundingChunks: p.groundingChunks,
})
if groundingText != "" {
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
"type": "text",
"text": "",
}))
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
"text": groundingText,
}))
_, _ = result.Write(p.endBlock())
}
}
// 确定 stop_reason
stopReason := "end_turn"
if p.usedTool {

View File

@@ -16,14 +16,11 @@ type ModelsListResponse struct {
func DefaultModels() []Model {
methods := []string{"generateContent", "streamGenerateContent"}
return []Model{
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
}
}

View File

@@ -12,10 +12,10 @@ type Model struct {
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var DefaultModels = []Model{
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.

View File

@@ -13,20 +13,26 @@ import (
"time"
)
// Claude OAuth Constants (from CRS project)
// Claude OAuth Constants
const (
// OAuth Client ID for Claude
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
// OAuth endpoints
AuthorizeURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
TokenURL = "https://platform.claude.com/v1/oauth/token"
RedirectURI = "https://platform.claude.com/oauth/code/callback"
// Scopes
ScopeProfile = "user:profile"
// Scopes - Browser URL (includes org:create_api_key for user authorization)
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Internal API call (org:create_api_key not supported in API)
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Setup token (inference only)
ScopeInference = "user:inference"
// Code Verifier character set (RFC 7636 compliant)
codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
// Session TTL
SessionTTL = 30 * time.Minute
)
@@ -53,7 +59,6 @@ func NewSessionStore() *SessionStore {
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
@@ -78,7 +83,6 @@ func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
@@ -122,13 +126,13 @@ func GenerateRandomBytes(n int) ([]byte, error) {
return b, nil
}
// GenerateState generates a random state string for OAuth
// GenerateState generates a random state string for OAuth (base64url encoded)
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
return base64URLEncode(bytes), nil
}
// GenerateSessionID generates a unique session ID
@@ -140,13 +144,30 @@ func GenerateSessionID() (string, error) {
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
// GenerateCodeVerifier generates a PKCE code verifier using character set method
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
const targetLen = 32
charsetLen := len(codeVerifierCharset)
limit := 256 - (256 % charsetLen)
result := make([]byte, 0, targetLen)
randBuf := make([]byte, targetLen*2)
for len(result) < targetLen {
if _, err := rand.Read(randBuf); err != nil {
return "", err
}
for _, b := range randBuf {
if int(b) < limit {
result = append(result, codeVerifierCharset[int(b)%charsetLen])
if len(result) >= targetLen {
break
}
}
}
}
return base64URLEncode(bytes), nil
return base64URLEncode(result), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
@@ -158,42 +179,31 @@ func GenerateCodeChallenge(verifier string) string {
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OAuth authorization URL
// BuildAuthorizationURL builds the OAuth authorization URL with correct parameter order
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("scope", scope)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
encodedRedirectURI := url.QueryEscape(RedirectURI)
encodedScope := strings.ReplaceAll(url.QueryEscape(scope), "%20", "+")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
return fmt.Sprintf("%s?code=true&client_id=%s&response_type=code&redirect_uri=%s&scope=%s&code_challenge=%s&code_challenge_method=S256&state=%s",
AuthorizeURL,
ClientID,
encodedRedirectURI,
encodedScope,
codeChallenge,
state,
)
}
// TokenResponse represents the token response from OAuth provider
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
// Organization and Account info from OAuth response
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
Organization *OrgInfo `json:"organization,omitempty"`
Account *AccountInfo `json:"account,omitempty"`
}
@@ -205,33 +215,6 @@ type OrgInfo struct {
// AccountInfo represents account info from OAuth response
type AccountInfo struct {
UUID string `json:"uuid"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
}
// BuildTokenRequest creates a token exchange request
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: RedirectURI,
CodeVerifier: codeVerifier,
State: state,
}
}
// BuildRefreshTokenRequest creates a refresh token request
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
}
UUID string `json:"uuid"`
EmailAddress string `json:"email_address"`
}

View File

@@ -162,11 +162,11 @@ func ParsePagination(c *gin.Context) (page, pageSize int) {
// 支持 page_size 和 limit 两种参数名
if ps := c.Query("page_size"); ps != "" {
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
if val, err := parseInt(ps); err == nil && val > 0 && val <= 1000 {
pageSize = val
}
} else if l := c.Query("limit"); l != "" {
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
if val, err := parseInt(l); err == nil && val > 0 && val <= 1000 {
pageSize = val
}
}

View File

@@ -0,0 +1,568 @@
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
// It uses the utls library to create TLS connections that mimic Node.js/Claude Code clients.
package tlsfingerprint
import (
"bufio"
"context"
"encoding/base64"
"fmt"
"log/slog"
"net"
"net/http"
"net/url"
utls "github.com/refraction-networking/utls"
"golang.org/x/net/proxy"
)
// Profile contains TLS fingerprint configuration.
type Profile struct {
Name string // Profile name for identification
CipherSuites []uint16
Curves []uint16
PointFormats []uint8
EnableGREASE bool
}
// Dialer creates TLS connections with custom fingerprints.
type Dialer struct {
profile *Profile
baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)
}
// HTTPProxyDialer creates TLS connections through HTTP/HTTPS proxies with custom fingerprints.
// It handles the CONNECT tunnel establishment before performing TLS handshake.
type HTTPProxyDialer struct {
profile *Profile
proxyURL *url.URL
}
// SOCKS5ProxyDialer creates TLS connections through SOCKS5 proxies with custom fingerprints.
// It uses golang.org/x/net/proxy to establish the SOCKS5 tunnel.
type SOCKS5ProxyDialer struct {
profile *Profile
proxyURL *url.URL
}
// Default TLS fingerprint values captured from Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)
// Captured using: tshark -i lo -f "tcp port 8443" -Y "tls.handshake.type == 1" -V
// JA3 Hash: 1a28e69016765d92e3b381168d68922c
//
// Note: JA3/JA4 may have slight variations due to:
// - Session ticket presence/absence
// - Extension negotiation state
var (
// defaultCipherSuites contains all 59 cipher suites from Claude CLI
// Order is critical for JA3 fingerprint matching
defaultCipherSuites = []uint16{
// TLS 1.3 cipher suites (MUST be first)
0x1302, // TLS_AES_256_GCM_SHA384
0x1303, // TLS_CHACHA20_POLY1305_SHA256
0x1301, // TLS_AES_128_GCM_SHA256
// ECDHE + AES-GCM
0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
0xc02b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
0xc02c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
// DHE + AES-GCM
0x009e, // TLS_DHE_RSA_WITH_AES_128_GCM_SHA256
// ECDHE/DHE + AES-CBC-SHA256/384
0xc027, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256
0x0067, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA256
0xc028, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384
0x006b, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA256
// DHE-DSS/RSA + AES-GCM
0x00a3, // TLS_DHE_DSS_WITH_AES_256_GCM_SHA384
0x009f, // TLS_DHE_RSA_WITH_AES_256_GCM_SHA384
// ChaCha20-Poly1305
0xcca9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
0xcca8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
0xccaa, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256
// AES-CCM (256-bit)
0xc0af, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8
0xc0ad, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM
0xc0a3, // TLS_DHE_RSA_WITH_AES_256_CCM_8
0xc09f, // TLS_DHE_RSA_WITH_AES_256_CCM
// ARIA (256-bit)
0xc05d, // TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384
0xc061, // TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384
0xc057, // TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384
0xc053, // TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384
// DHE-DSS + AES-GCM (128-bit)
0x00a2, // TLS_DHE_DSS_WITH_AES_128_GCM_SHA256
// AES-CCM (128-bit)
0xc0ae, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8
0xc0ac, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM
0xc0a2, // TLS_DHE_RSA_WITH_AES_128_CCM_8
0xc09e, // TLS_DHE_RSA_WITH_AES_128_CCM
// ARIA (128-bit)
0xc05c, // TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256
0xc060, // TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256
0xc056, // TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256
0xc052, // TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256
// ECDHE/DHE + AES-CBC-SHA384/256 (more)
0xc024, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384
0x006a, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA256
0xc023, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256
0x0040, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA256
// ECDHE/DHE + AES-CBC-SHA (legacy)
0xc00a, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
0xc014, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA
0x0039, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA
0x0038, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA
0xc009, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA
0xc013, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA
0x0033, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA
0x0032, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA
// RSA + AES-GCM/CCM/ARIA (non-PFS, 256-bit)
0x009d, // TLS_RSA_WITH_AES_256_GCM_SHA384
0xc0a1, // TLS_RSA_WITH_AES_256_CCM_8
0xc09d, // TLS_RSA_WITH_AES_256_CCM
0xc051, // TLS_RSA_WITH_ARIA_256_GCM_SHA384
// RSA + AES-GCM/CCM/ARIA (non-PFS, 128-bit)
0x009c, // TLS_RSA_WITH_AES_128_GCM_SHA256
0xc0a0, // TLS_RSA_WITH_AES_128_CCM_8
0xc09c, // TLS_RSA_WITH_AES_128_CCM
0xc050, // TLS_RSA_WITH_ARIA_128_GCM_SHA256
// RSA + AES-CBC (non-PFS, legacy)
0x003d, // TLS_RSA_WITH_AES_256_CBC_SHA256
0x003c, // TLS_RSA_WITH_AES_128_CBC_SHA256
0x0035, // TLS_RSA_WITH_AES_256_CBC_SHA
0x002f, // TLS_RSA_WITH_AES_128_CBC_SHA
// Renegotiation indication
0x00ff, // TLS_EMPTY_RENEGOTIATION_INFO_SCSV
}
// defaultCurves contains the 10 supported groups from Claude CLI (including FFDHE)
defaultCurves = []utls.CurveID{
utls.X25519, // 0x001d
utls.CurveP256, // 0x0017 (secp256r1)
utls.CurveID(0x001e), // x448
utls.CurveP521, // 0x0019 (secp521r1)
utls.CurveP384, // 0x0018 (secp384r1)
utls.CurveID(0x0100), // ffdhe2048
utls.CurveID(0x0101), // ffdhe3072
utls.CurveID(0x0102), // ffdhe4096
utls.CurveID(0x0103), // ffdhe6144
utls.CurveID(0x0104), // ffdhe8192
}
// defaultPointFormats contains all 3 point formats from Claude CLI
defaultPointFormats = []uint8{
0, // uncompressed
1, // ansiX962_compressed_prime
2, // ansiX962_compressed_char2
}
// defaultSignatureAlgorithms contains the 20 signature algorithms from Claude CLI
defaultSignatureAlgorithms = []utls.SignatureScheme{
0x0403, // ecdsa_secp256r1_sha256
0x0503, // ecdsa_secp384r1_sha384
0x0603, // ecdsa_secp521r1_sha512
0x0807, // ed25519
0x0808, // ed448
0x0809, // rsa_pss_pss_sha256
0x080a, // rsa_pss_pss_sha384
0x080b, // rsa_pss_pss_sha512
0x0804, // rsa_pss_rsae_sha256
0x0805, // rsa_pss_rsae_sha384
0x0806, // rsa_pss_rsae_sha512
0x0401, // rsa_pkcs1_sha256
0x0501, // rsa_pkcs1_sha384
0x0601, // rsa_pkcs1_sha512
0x0303, // ecdsa_sha224
0x0301, // rsa_pkcs1_sha224
0x0302, // dsa_sha224
0x0402, // dsa_sha256
0x0502, // dsa_sha384
0x0602, // dsa_sha512
}
)
// NewDialer creates a new TLS fingerprint dialer.
// baseDialer is used for TCP connection establishment (supports proxy scenarios).
// If baseDialer is nil, direct TCP dial is used.
func NewDialer(profile *Profile, baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *Dialer {
if baseDialer == nil {
baseDialer = (&net.Dialer{}).DialContext
}
return &Dialer{profile: profile, baseDialer: baseDialer}
}
// NewHTTPProxyDialer creates a new TLS fingerprint dialer that works through HTTP/HTTPS proxies.
// It establishes a CONNECT tunnel before performing TLS handshake with custom fingerprint.
func NewHTTPProxyDialer(profile *Profile, proxyURL *url.URL) *HTTPProxyDialer {
return &HTTPProxyDialer{profile: profile, proxyURL: proxyURL}
}
// NewSOCKS5ProxyDialer creates a new TLS fingerprint dialer that works through SOCKS5 proxies.
// It establishes a SOCKS5 tunnel before performing TLS handshake with custom fingerprint.
func NewSOCKS5ProxyDialer(profile *Profile, proxyURL *url.URL) *SOCKS5ProxyDialer {
return &SOCKS5ProxyDialer{profile: profile, proxyURL: proxyURL}
}
// DialTLSContext establishes a TLS connection through SOCKS5 proxy with the configured fingerprint.
// Flow: SOCKS5 CONNECT to target -> TLS handshake with utls on the tunnel
func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
slog.Debug("tls_fingerprint_socks5_connecting", "proxy", d.proxyURL.Host, "target", addr)
// Step 1: Create SOCKS5 dialer
var auth *proxy.Auth
if d.proxyURL.User != nil {
username := d.proxyURL.User.Username()
password, _ := d.proxyURL.User.Password()
auth = &proxy.Auth{
User: username,
Password: password,
}
}
// Determine proxy address
proxyAddr := d.proxyURL.Host
if d.proxyURL.Port() == "" {
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "1080") // Default SOCKS5 port
}
socksDialer, err := proxy.SOCKS5("tcp", proxyAddr, auth, proxy.Direct)
if err != nil {
slog.Debug("tls_fingerprint_socks5_dialer_failed", "error", err)
return nil, fmt.Errorf("create SOCKS5 dialer: %w", err)
}
// Step 2: Establish SOCKS5 tunnel to target
slog.Debug("tls_fingerprint_socks5_establishing_tunnel", "target", addr)
conn, err := socksDialer.Dial("tcp", addr)
if err != nil {
slog.Debug("tls_fingerprint_socks5_connect_failed", "error", err)
return nil, fmt.Errorf("SOCKS5 connect: %w", err)
}
slog.Debug("tls_fingerprint_socks5_tunnel_established")
// Step 3: Perform TLS handshake on the tunnel with utls fingerprint
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
slog.Debug("tls_fingerprint_socks5_starting_handshake", "host", host)
// Build ClientHello specification from profile (Node.js/Claude CLI fingerprint)
spec := buildClientHelloSpecFromProfile(d.profile)
slog.Debug("tls_fingerprint_socks5_clienthello_spec",
"cipher_suites", len(spec.CipherSuites),
"extensions", len(spec.Extensions),
"compression_methods", spec.CompressionMethods,
"tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
"tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
if d.profile != nil {
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
}
// Create uTLS connection on the tunnel
tlsConn := utls.UClient(conn, &utls.Config{
ServerName: host,
}, utls.HelloCustom)
if err := tlsConn.ApplyPreset(spec); err != nil {
slog.Debug("tls_fingerprint_socks5_apply_preset_failed", "error", err)
_ = conn.Close()
return nil, fmt.Errorf("apply TLS preset: %w", err)
}
if err := tlsConn.Handshake(); err != nil {
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
_ = conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
}
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_socks5_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
}
// DialTLSContext establishes a TLS connection through HTTP proxy with the configured fingerprint.
// Flow: TCP connect to proxy -> CONNECT tunnel -> TLS handshake with utls
func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
slog.Debug("tls_fingerprint_http_proxy_connecting", "proxy", d.proxyURL.Host, "target", addr)
// Step 1: TCP connect to proxy server
var proxyAddr string
if d.proxyURL.Port() != "" {
proxyAddr = d.proxyURL.Host
} else {
// Default ports
if d.proxyURL.Scheme == "https" {
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "443")
} else {
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "80")
}
}
dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, "tcp", proxyAddr)
if err != nil {
slog.Debug("tls_fingerprint_http_proxy_connect_failed", "error", err)
return nil, fmt.Errorf("connect to proxy: %w", err)
}
slog.Debug("tls_fingerprint_http_proxy_connected", "proxy_addr", proxyAddr)
// Step 2: Send CONNECT request to establish tunnel
req := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: addr},
Host: addr,
Header: make(http.Header),
}
// Add proxy authentication if present
if d.proxyURL.User != nil {
username := d.proxyURL.User.Username()
password, _ := d.proxyURL.User.Password()
auth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
req.Header.Set("Proxy-Authorization", "Basic "+auth)
}
slog.Debug("tls_fingerprint_http_proxy_sending_connect", "target", addr)
if err := req.Write(conn); err != nil {
_ = conn.Close()
slog.Debug("tls_fingerprint_http_proxy_write_failed", "error", err)
return nil, fmt.Errorf("write CONNECT request: %w", err)
}
// Step 3: Read CONNECT response
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, req)
if err != nil {
_ = conn.Close()
slog.Debug("tls_fingerprint_http_proxy_read_response_failed", "error", err)
return nil, fmt.Errorf("read CONNECT response: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
_ = conn.Close()
slog.Debug("tls_fingerprint_http_proxy_connect_failed_status", "status_code", resp.StatusCode, "status", resp.Status)
return nil, fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
}
slog.Debug("tls_fingerprint_http_proxy_tunnel_established")
// Step 4: Perform TLS handshake on the tunnel with utls fingerprint
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
slog.Debug("tls_fingerprint_http_proxy_starting_handshake", "host", host)
// Build ClientHello specification (reuse the shared method)
spec := buildClientHelloSpecFromProfile(d.profile)
slog.Debug("tls_fingerprint_http_proxy_clienthello_spec",
"cipher_suites", len(spec.CipherSuites),
"extensions", len(spec.Extensions))
if d.profile != nil {
slog.Debug("tls_fingerprint_http_proxy_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
}
// Create uTLS connection on the tunnel
// Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
tlsConn := utls.UClient(conn, &utls.Config{
ServerName: host,
}, utls.HelloCustom)
if err := tlsConn.ApplyPreset(spec); err != nil {
slog.Debug("tls_fingerprint_http_proxy_apply_preset_failed", "error", err)
_ = conn.Close()
return nil, fmt.Errorf("apply TLS preset: %w", err)
}
if err := tlsConn.HandshakeContext(ctx); err != nil {
slog.Debug("tls_fingerprint_http_proxy_handshake_failed", "error", err)
_ = conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
}
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_http_proxy_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
}
// DialTLSContext establishes a TLS connection with the configured fingerprint.
// This method is designed to be used as http.Transport.DialTLSContext.
func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
// Establish TCP connection using base dialer (supports proxy)
slog.Debug("tls_fingerprint_dialing_tcp", "addr", addr)
conn, err := d.baseDialer(ctx, network, addr)
if err != nil {
slog.Debug("tls_fingerprint_tcp_dial_failed", "error", err)
return nil, err
}
slog.Debug("tls_fingerprint_tcp_connected", "addr", addr)
// Extract hostname for SNI
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
slog.Debug("tls_fingerprint_sni_hostname", "host", host)
// Build ClientHello specification
spec := d.buildClientHelloSpec()
slog.Debug("tls_fingerprint_clienthello_spec",
"cipher_suites", len(spec.CipherSuites),
"extensions", len(spec.Extensions))
// Log profile info
if d.profile != nil {
slog.Debug("tls_fingerprint_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
} else {
slog.Debug("tls_fingerprint_using_default_profile")
}
// Create uTLS connection
// Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
tlsConn := utls.UClient(conn, &utls.Config{
ServerName: host,
}, utls.HelloCustom)
// Apply fingerprint
if err := tlsConn.ApplyPreset(spec); err != nil {
slog.Debug("tls_fingerprint_apply_preset_failed", "error", err)
_ = conn.Close()
return nil, err
}
slog.Debug("tls_fingerprint_preset_applied")
// Perform TLS handshake
if err := tlsConn.HandshakeContext(ctx); err != nil {
slog.Debug("tls_fingerprint_handshake_failed",
"error", err,
"local_addr", conn.LocalAddr(),
"remote_addr", conn.RemoteAddr())
_ = conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
}
// Log successful handshake details
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
}
// buildClientHelloSpec constructs the ClientHello specification based on the profile.
func (d *Dialer) buildClientHelloSpec() *utls.ClientHelloSpec {
return buildClientHelloSpecFromProfile(d.profile)
}
// toUTLSCurves converts uint16 slice to utls.CurveID slice.
func toUTLSCurves(curves []uint16) []utls.CurveID {
result := make([]utls.CurveID, len(curves))
for i, c := range curves {
result[i] = utls.CurveID(c)
}
return result
}
// buildClientHelloSpecFromProfile constructs ClientHelloSpec from a Profile.
// This is a standalone function that can be used by both Dialer and HTTPProxyDialer.
func buildClientHelloSpecFromProfile(profile *Profile) *utls.ClientHelloSpec {
// Get cipher suites
var cipherSuites []uint16
if profile != nil && len(profile.CipherSuites) > 0 {
cipherSuites = profile.CipherSuites
} else {
cipherSuites = defaultCipherSuites
}
// Get curves
var curves []utls.CurveID
if profile != nil && len(profile.Curves) > 0 {
curves = toUTLSCurves(profile.Curves)
} else {
curves = defaultCurves
}
// Get point formats
var pointFormats []uint8
if profile != nil && len(profile.PointFormats) > 0 {
pointFormats = profile.PointFormats
} else {
pointFormats = defaultPointFormats
}
// Check if GREASE is enabled
enableGREASE := profile != nil && profile.EnableGREASE
extensions := make([]utls.TLSExtension, 0, 16)
if enableGREASE {
extensions = append(extensions, &utls.UtlsGREASEExtension{})
}
// SNI extension - MUST be explicitly added for HelloCustom mode
// utls will populate the server name from Config.ServerName
extensions = append(extensions, &utls.SNIExtension{})
// Claude CLI extension order (captured from tshark):
// server_name(0), ec_point_formats(11), supported_groups(10), session_ticket(35),
// alpn(16), encrypt_then_mac(22), extended_master_secret(23),
// signature_algorithms(13), supported_versions(43),
// psk_key_exchange_modes(45), key_share(51)
extensions = append(extensions,
&utls.SupportedPointsExtension{SupportedPoints: pointFormats},
&utls.SupportedCurvesExtension{Curves: curves},
&utls.SessionTicketExtension{},
&utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}},
&utls.GenericExtension{Id: 22},
&utls.ExtendedMasterSecretExtension{},
&utls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: defaultSignatureAlgorithms},
&utls.SupportedVersionsExtension{Versions: []uint16{
utls.VersionTLS13,
utls.VersionTLS12,
}},
&utls.PSKKeyExchangeModesExtension{Modes: []uint8{utls.PskModeDHE}},
&utls.KeyShareExtension{KeyShares: []utls.KeyShare{
{Group: utls.X25519},
}},
)
if enableGREASE {
extensions = append(extensions, &utls.UtlsGREASEExtension{})
}
return &utls.ClientHelloSpec{
CipherSuites: cipherSuites,
CompressionMethods: []uint8{0}, // null compression only (standard)
Extensions: extensions,
TLSVersMax: utls.VersionTLS13,
TLSVersMin: utls.VersionTLS10,
}
}

View File

@@ -0,0 +1,278 @@
//go:build integration
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests to external services and should be run manually.
//
// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
)
// skipIfExternalServiceUnavailable checks if the external service is available.
// If not, it skips the test instead of failing.
func skipIfExternalServiceUnavailable(t *testing.T, err error) {
t.Helper()
if err != nil {
// Check for common network/TLS errors that indicate external service issues
errStr := err.Error()
if strings.Contains(errStr, "certificate has expired") ||
strings.Contains(errStr, "certificate is not yet valid") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "no such host") ||
strings.Contains(errStr, "network is unreachable") ||
strings.Contains(errStr, "timeout") {
t.Skipf("skipping test: external service unavailable: %v", err)
}
t.Fatalf("failed to get fingerprint: %v", err)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
skipIfExternalServiceUnavailable(t, err)
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
skipIfExternalServiceUnavailable(t, err)
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}

View File

@@ -0,0 +1,160 @@
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Unit tests for TLS fingerprint dialer.
// Integration tests that require external network are in dialer_integration_test.go
// and require the 'integration' build tag.
//
// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
"net/url"
"testing"
)
// FingerprintResponse represents the response from tls.peet.ws/api/all.
type FingerprintResponse struct {
IP string `json:"ip"`
TLS TLSInfo `json:"tls"`
HTTP2 any `json:"http2"`
}
// TLSInfo contains TLS fingerprint details.
type TLSInfo struct {
JA3 string `json:"ja3"`
JA3Hash string `json:"ja3_hash"`
JA4 string `json:"ja4"`
PeetPrint string `json:"peetprint"`
PeetPrintHash string `json:"peetprint_hash"`
ClientRandom string `json:"client_random"`
SessionID string `json:"session_id"`
}
// TestDialerWithProfile tests that different profiles produce different fingerprints.
func TestDialerWithProfile(t *testing.T) {
// Create two dialers with different profiles
profile1 := &Profile{
Name: "Profile 1 - No GREASE",
EnableGREASE: false,
}
profile2 := &Profile{
Name: "Profile 2 - With GREASE",
EnableGREASE: true,
}
dialer1 := NewDialer(profile1, nil)
dialer2 := NewDialer(profile2, nil)
// Build specs and compare
// Note: We can't directly compare JA3 without making network requests
// but we can verify the specs are different
spec1 := dialer1.buildClientHelloSpec()
spec2 := dialer2.buildClientHelloSpec()
// Profile with GREASE should have more extensions
if len(spec2.Extensions) <= len(spec1.Extensions) {
t.Error("expected GREASE profile to have more extensions")
}
}
// TestHTTPProxyDialerBasic tests HTTP proxy dialer creation.
// Note: This is a unit test - actual proxy testing requires a proxy server.
func TestHTTPProxyDialerBasic(t *testing.T) {
profile := &Profile{
Name: "Test Profile",
EnableGREASE: false,
}
// Test that dialer is created without panic
proxyURL := mustParseURL("http://proxy.example.com:8080")
dialer := NewHTTPProxyDialer(profile, proxyURL)
if dialer == nil {
t.Fatal("expected dialer to be created")
}
if dialer.profile != profile {
t.Error("expected profile to be set")
}
if dialer.proxyURL != proxyURL {
t.Error("expected proxyURL to be set")
}
}
// TestSOCKS5ProxyDialerBasic tests SOCKS5 proxy dialer creation.
// Note: This is a unit test - actual proxy testing requires a proxy server.
func TestSOCKS5ProxyDialerBasic(t *testing.T) {
profile := &Profile{
Name: "Test Profile",
EnableGREASE: false,
}
// Test that dialer is created without panic
proxyURL := mustParseURL("socks5://proxy.example.com:1080")
dialer := NewSOCKS5ProxyDialer(profile, proxyURL)
if dialer == nil {
t.Fatal("expected dialer to be created")
}
if dialer.profile != profile {
t.Error("expected profile to be set")
}
if dialer.proxyURL != proxyURL {
t.Error("expected proxyURL to be set")
}
}
// TestBuildClientHelloSpec tests ClientHello spec construction.
func TestBuildClientHelloSpec(t *testing.T) {
// Test with nil profile (should use defaults)
spec := buildClientHelloSpecFromProfile(nil)
if len(spec.CipherSuites) == 0 {
t.Error("expected cipher suites to be set")
}
if len(spec.Extensions) == 0 {
t.Error("expected extensions to be set")
}
// Verify default cipher suites are used
if len(spec.CipherSuites) != len(defaultCipherSuites) {
t.Errorf("expected %d cipher suites, got %d", len(defaultCipherSuites), len(spec.CipherSuites))
}
// Test with custom profile
customProfile := &Profile{
Name: "Custom",
EnableGREASE: false,
CipherSuites: []uint16{0x1301, 0x1302},
}
spec = buildClientHelloSpecFromProfile(customProfile)
if len(spec.CipherSuites) != 2 {
t.Errorf("expected 2 cipher suites, got %d", len(spec.CipherSuites))
}
}
// TestToUTLSCurves tests curve ID conversion.
func TestToUTLSCurves(t *testing.T) {
input := []uint16{0x001d, 0x0017, 0x0018}
result := toUTLSCurves(input)
if len(result) != len(input) {
t.Errorf("expected %d curves, got %d", len(input), len(result))
}
for i, curve := range result {
if uint16(curve) != input[i] {
t.Errorf("curve %d: expected 0x%04x, got 0x%04x", i, input[i], uint16(curve))
}
}
}
// Helper function to parse URL without error handling.
func mustParseURL(rawURL string) *url.URL {
u, err := url.Parse(rawURL)
if err != nil {
panic(err)
}
return u
}

Some files were not shown because too many files have changed in this diff Show More