Compare commits

..

115 Commits

Author SHA1 Message Date
shaw
7b1d63a786 fix(types): 添加缺失的 ignore_invalid_api_key_errors 类型定义
OpsAdvancedSettings 接口缺少 ignore_invalid_api_key_errors 字段,
导致 TypeScript 编译报错。
2026-02-02 21:01:32 +08:00
Wesley Liddick
e204b4d81f Merge pull request #452 from s-Joshua-s/feat/enhance-usage-endpoint
feat(gateway): 增强 /v1/usage 端点返回完整用量统计
2026-02-02 20:35:00 +08:00
Wesley Liddick
325ed747d8 Merge pull request #455 from ZeroClover/feat/ops-ignore-invalid-api-key-errors
feat(ops): 支持过滤无效 API Key 错误,不写入错误日志
2026-02-02 20:28:00 +08:00
Wesley Liddick
cbf3dba28d Merge pull request #454 from ZeroClover/feat-exclude-user-inactive-errors
feat(ops): 将 USER_INACTIVE 错误排除在 SLA 统计之外
2026-02-02 20:19:48 +08:00
Wesley Liddick
4329f72abf Merge pull request #450 from bayma888/feature/show-admin-adjustment-notes
feat: 向用户显示管理员调整余额的备注
2026-02-02 20:19:23 +08:00
Zero Clover
ad1cdba338 feat(ops): 支持过滤无效 API Key 错误,不写入错误日志
新增 IgnoreInvalidApiKeyErrors 开关,启用后 INVALID_API_KEY 和
API_KEY_REQUIRED 错误将被完全跳过,不写入 Ops 错误日志。
这些错误由用户错误配置导致,与服务质量无关。
2026-02-02 20:16:17 +08:00
Wesley Liddick
016c3915d7 Merge pull request #449 from bayma888/feature/user-search-support-notes
feat: 支持在用户列表搜索中使用备注字段和模糊查询、支持用户名备注等搜索
2026-02-02 20:16:03 +08:00
shaw
79fa18132b fix(gateway): 修复 OAuth token 刷新后调度器缓存不一致问题
Token 刷新成功后,调度器缓存中的 Account 对象仍包含旧的 credentials,
导致在 Outbox 异步更新之前(最多 1 秒窗口)请求使用过期 token,
返回 403 错误(OAuth token has been revoked)。

修复方案:在 token 刷新成功后同步更新调度器缓存,确保调度获取的
Account 对象立即包含最新的 access_token 和 _token_version。

此修复覆盖所有 OAuth 平台:OpenAI、Claude、Gemini、Antigravity。
2026-02-02 20:05:37 +08:00
Zero Clover
673caf41a0 feat(ops): 将 USER_INACTIVE 错误排除在 SLA 统计之外
将账户停用 (USER_INACTIVE) 导致的请求失败视为业务限制类错误,不计入 SLA 和错误率统计。

账户停用是预期内的业务结果,不应被视为系统错误或服务质量问题。此改动使错误分类更加准确,避免将预期的业务限制误报为系统故障。

修改内容:
- 在 classifyOpsIsBusinessLimited 函数中添加 USER_INACTIVE 错误码
- 该类错误不再触发错误率告警

Fixes Wei-Shaw/sub2api#453
2026-02-02 18:50:54 +08:00
JIA-ss
c441638fc0 feat(gateway): 增强 /v1/usage 端点返回完整用量统计
为 CC Switch 集成增强 /v1/usage 网关端点,在保持原有 4 字段
(isValid, planName, remaining, unit) 向后兼容的基础上,新增:

- usage 对象:今日/累计的请求数、token 用量、费用,以及 RPM/TPM
- subscription 对象(订阅模式):日/周/月用量和限额、过期时间
- balance 字段(余额模式):当前钱包余额

用量数据获取采用 best-effort 策略,失败不影响基础响应。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 18:30:06 +08:00
小北
ae18397ca6 feat: 向用户显示管理员调整余额的备注
- 为RedeemCode DTO添加notes字段(仅用于admin_balance/admin_concurrency类型)
- 更新mapper使其有条件地包含备注信息
- 在用户兑换历史UI中显示备注
- 备注以斜体显示,悬停时显示完整内容

用户现在可以看到管理员调整其余额的原因说明。

Changes:
- backend/internal/handler/dto/types.go: RedeemCode添加notes字段
- backend/internal/handler/dto/mappers.go: 条件性填充notes
- frontend/src/api/redeem.ts: TypeScript接口添加notes
- frontend/src/views/user/RedeemView.vue: UI显示备注信息
2026-02-02 17:44:50 +08:00
小北
426ce616c0 feat: 支持在用户搜索中使用备注字段
- 在用户仓库的搜索过滤器中添加备注字段
- 管理员现在可以通过备注/标记搜索用户
- 使用不区分大小写的搜索(ContainsFold)

Changes:
- backend/internal/repository/user_repo.go: 添加 NotesContainsFold 到搜索条件
2026-02-02 17:41:27 +08:00
shaw
5cda979209 feat(deploy): 优化 Docker 部署体验,新增一键部署脚本
## 新增功能

- 新增 docker-compose.local.yml:使用本地目录存储数据,便于迁移和备份
- 新增 docker-deploy.sh:一键部署脚本,自动生成安全密钥(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
- 新增 deploy/.gitignore:忽略运行时数据目录

## 优化改进

- docker-compose.local.yml 包含 PGDATA 环境变量修复,解决 PostgreSQL 18 Alpine 数据丢失问题
- 脚本自动设置 .env 文件权限为 600,增强安全性
- 脚本显示生成的凭证,方便用户记录

## 文档更新

- 更新 README.md(英文版):新增"快速开始"章节,添加部署版本对比表
- 更新 README_CN.md(中文版):同步英文版更新
- 更新 deploy/README.md:详细说明两种部署方式和迁移方法

## 使用方式

一键部署:
```bash
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
docker-compose -f docker-compose.local.yml up -d
```

轻松迁移:
```bash
tar czf sub2api-complete.tar.gz deploy/
# 传输到新服务器后直接解压启动即可
```
2026-02-02 16:17:07 +08:00
Wesley Liddick
cc7e67b01a Merge pull request #445 from touwaeriol/fix/gemini-cache-token-billing
fix(billing): 修复 Gemini 接口缓存 token 统计
2026-02-02 15:22:46 +08:00
Wesley Liddick
6999a9c011 Merge pull request #444 from touwaeriol/fix/gemini-apikey-passthrough
feat(gateway): Gemini API Key 账户跳过模型映射检查,直接透传
2026-02-02 15:17:05 +08:00
shaw
bbdc8663d3 feat: 重新设计公告系统为Header铃铛通知
- 新增 AnnouncementBell 组件,支持 Modal 弹窗和 Markdown 渲染
- 移除 Dashboard 横幅和独立公告页面
- 铃铛位置在 Header 文档按钮左侧,显示未读红点
- 支持点击查看详情、标记已读、全部已读等操作
- 完善国际化,移除所有硬编码中文
- 修复 AnnouncementTargetingEditor watch 循环问题
2026-02-02 15:15:39 +08:00
liuxiongfeng
4bfeeecb05 fix(billing): 修复 Gemini 接口缓存 token 统计
extractGeminiUsage 函数未提取 cachedContentTokenCount,
导致计费时缓存读取 token 始终为 0。

修复:
- 提取 usageMetadata.cachedContentTokenCount
- 设置 CacheReadInputTokens 字段
- InputTokens 减去缓存 token(与 response_transformer 逻辑一致)
2026-02-02 14:01:17 +08:00
liuxiongfeng
bbc7b4aeed feat(gateway): Gemini API Key 账户跳过模型映射检查,直接透传
Gemini API Key 账户通常代理上游服务,模型支持由上游判断,
本地不需要预先配置模型映射。
2026-02-02 13:40:29 +08:00
Wesley Liddick
d3062b2e46 Merge pull request #434 from DuckyProject/feat/announcement-system-pr-upstream
feat(announcements): admin/user announcement system
2026-02-02 10:50:26 +08:00
Wesley Liddick
b7777fb46c Merge pull request #436 from iBenzene/feat/redis-tls-support
feat: add support for using TLS to connect to Redis
2026-02-02 10:02:25 +08:00
iBenzene
35f39ca291 chore: 修复了 redis.go 中代码风格(golangci-lint)的问题 2026-01-31 19:06:19 +08:00
iBenzene
f2e206700c feat: add support for using TLS to connect to Redis 2026-01-31 03:58:01 +08:00
ducky
9bee0a2071 chore: gofmt for golangci-lint 2026-01-30 17:28:53 +08:00
ducky
b7f69844e1 feat(announcements): add admin/user announcement system
Implements announcements end-to-end (admin CRUD + read status, user list + mark read) with OR-of-AND targeting. Also breaks the ent<->service import cycle by moving schema-facing constants/targeting into a new domain package.
2026-01-30 16:45:04 +08:00
Wesley Liddick
c3d1891ccd Merge pull request #427 from touwaeriol/pr/upgrade-antigravity-ua
chore: upgrade Antigravity User-Agent to 1.15.8
2026-01-30 09:17:17 +08:00
shaw
4d8f2db924 fix: 更新所有CI workflow的Go版本验证至1.25.6 2026-01-30 08:57:37 +08:00
shaw
6599b366dc fix: 升级Go版本至1.25.6修复标准库安全漏洞
修复GO-2026-4341和GO-2026-4340两个标准库漏洞
2026-01-30 08:53:53 +08:00
liuxiongfeng
ba16ace697 chore: upgrade Antigravity User-Agent to 1.15.8 2026-01-30 08:14:52 +08:00
shaw
cadca752c4 修复SSE流式响应中usage数据被覆盖的问题 2026-01-28 18:36:21 +08:00
Wesley Liddick
edf215e6fd Merge pull request #409 from DuckyProject/feat/purchase-subscription-iframe
feat(purchase): 增加购买订阅 iframe 页面与配置
2026-01-28 17:28:47 +08:00
shaw
e12dd079fd 修复调度器空缓存导致的竞态条件bug
当新分组创建后立即绑定账号时,调度器会错误地将空快照视为有效缓存命中,
导致返回没有可调度的账号。现在空快照会触发数据库回退查询。
2026-01-28 17:26:32 +08:00
ducky
04a509d45e feat(purchase): 增加购买订阅 iframe 页面与配置
- 新增 /purchase 页面(iframe + 新窗口兜底)

- 管理员系统设置可配置开关与URL

- 非 simple mode 才在侧边栏展示入口
2026-01-28 13:54:32 +08:00
Wesley Liddick
269a659200 Merge pull request #406 from geminiwen/main
fix(openai-oauth): 改进错误处理和代理支持
2026-01-28 13:53:44 +08:00
Wesley Liddick
2c31bf46b5 Merge pull request #401 from slovx2/heihuzi_main
feat(gemini): 为 Gemini 原生平台添加图片计费支持
2026-01-28 13:51:14 +08:00
Gemini Wen
8f6639f825 fix(response): add nil check for c.Request in error logging
Prevents panic when ErrorFrom is called in test contexts where
gin.CreateTestContext doesn't set up an HTTP request.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 19:26:44 +08:00
Gemini Wen
fc17d9d7df chore: bump version to 0.1.61 and fix tests
- Update VERSION from 0.1.46 to 0.1.61
- Remove ForceHTTP2 tests for OpenAI OAuth client (ForceHTTP2 was removed)
- Update createOpenAIReqClient test to use new single-arg signature

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 19:22:45 +08:00
Gemini Wen
ab092e88a8 fix(openai-oauth): 改进错误处理和代理支持
- 使用 ApplicationError 返回详细错误信息到前端
- 添加 User-Agent: codex-cli/0.91.0
- 移除 ForceHTTP2 以兼容 HTTP 代理
- 修复代理获取失败时静默忽略的问题
- 500 错误时记录完整错误日志

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 19:13:01 +08:00
shaw
56a1e29cdd fix(gateway): 修复 SSE 流式响应 usage 统计错误
message_delta 应完全覆盖 message_start 的 usage 数据,
而非仅在值为 0 时才更新。
2026-01-27 09:16:34 +08:00
song
0059a232a6 feat(gemini): 为 Gemini 原生平台添加图片计费支持
对齐 Antigravity 平台的图片计费逻辑:
- 添加 extractImageSize() 方法提取图片尺寸
- Forward() 和 ForwardNative() 返回 ImageCount/ImageSize
- 支持分组自定义图片价格和倍率
2026-01-26 20:51:40 +08:00
shaw
45676fdc8d fix(ci): 转义 Telegram 消息中的 Markdown 特殊字符
修复发布通知发送失败的问题,原因是 tag message 中包含未闭合的
Markdown 格式标记(如 user_id 中的 _ 被解析为斜体开始)导致
Telegram API 返回解析错误。

添加 sed 命令转义 _、*、` 和 [ 字符,避免被 Telegram Markdown
解析器错误处理。
2026-01-26 11:07:08 +08:00
Wesley Liddick
e32c5f534f Merge pull request #386 from IanShaw027/fix/openai-usage-limit-reset-time
fix(ratelimit): 修复 OpenAI usage_limit_reached 错误的重置时间解析
2026-01-26 10:22:42 +08:00
shaw
426d691c95 fix(urlvalidator): 移除 ValidateURLFormat 返回值的末尾斜杠
修复 API Key 账号 base_url 末尾带斜杠时导致的双斜杠问题
2026-01-26 10:21:41 +08:00
shaw
e9a4c8ab97 docs: 修改演示站点域名 2026-01-26 10:04:44 +08:00
ianshaw
a55cfebd09 fix(ratelimit): 修复 OpenAI usage_limit_reached 错误的重置时间解析
- 问题:OpenAI 的 usage_limit_reached 错误(需 37 小时重置)被错误地设置为 5 分钟
- 原因:handle429 只检查 Anthropic 响应头,没有解析 OpenAI 响应体中的 resets_in_seconds
- 修复:新增 parseOpenAIRateLimitResetTime 函数解析 OpenAI 响应体
- 影响:避免调度器不断尝试已达配额上限的账户
2026-01-26 09:57:44 +08:00
Wesley Liddick
34cc02f8c7 Merge pull request #393 from IanShaw027/fix/gemini-thought-signature-preserve
fix(gemini): 修复 thoughtSignature 跨账号验证错误
2026-01-26 09:23:46 +08:00
Wesley Liddick
624d9fddb7 Merge pull request #391 from geminiwen/main
fix(subscription): 修复订阅调整逻辑,已过期订阅从当前时间计算
2026-01-26 09:23:29 +08:00
Wesley Liddick
47fbe43324 Merge pull request #385 from DDZS987/fix/oauth-token-refresh-missing-project-id-retry
fix(oauth): 修复 OAuth 令牌刷新时 missing_project_id 误报问题
2026-01-26 09:22:48 +08:00
shaw
1245f07a2d feat(auth): 实现 TOTP 双因素认证功能
新增功能:
- 支持 Google Authenticator 等应用进行 TOTP 二次验证
- 用户可在个人设置中启用/禁用 2FA
- 登录时支持 TOTP 验证流程
- 管理后台可全局开关 TOTP 功能

安全增强:
- TOTP 密钥使用 AES-256-GCM 加密存储
- 添加 TOTP_ENCRYPTION_KEY 配置项,必须手动配置才能启用功能
- 防止服务重启导致加密密钥变更使用户无法登录
- 验证失败次数限制,防止暴力破解

配置说明:
- Docker 部署:在 .env 中设置 TOTP_ENCRYPTION_KEY
- 非 Docker 部署:在 config.yaml 中设置 totp.encryption_key
- 生成密钥命令:openssl rand -hex 32
2026-01-26 09:19:53 +08:00
ianshaw
839975b0cf feat(gemini): 支持 Gemini CLI 粘性会话与跨账号 thoughtSignature 清理
## 问题背景

1. Gemini CLI 没有明确的会话标识(如 Claude Code 的 metadata.user_id)
2. thoughtSignature 与具体上游账号强绑定,跨账号使用会导致 400 错误
3. 粘性会话切换账号或 cache 丢失时,旧签名会导致请求失败

## 解决方案

### 1. Gemini CLI 会话标识提取

- 从 `x-gemini-api-privileged-user-id` header 和请求体中的 tmp 目录哈希生成会话标识
- 组合策略:SHA256(privileged-user-id + ":" + tmp_dir_hash)
- 正则提取:`/\.gemini/tmp/([A-Fa-f0-9]{64})`

### 2. 跨账号 thoughtSignature 清理

实现三种场景的智能清理:

1. **Cache 命中 + 账号切换**
   - 粘性会话绑定的账号与当前选择的账号不同时清理

2. **同一请求内 failover 切换**
   - 通过 sessionBoundAccountID 跟踪,检测重试时的账号切换

3. **Gemini CLI + Cache 未命中 + 含签名**
   - 预防性清理,避免 cache 丢失后首次转发就 400
   - 仅对 Gemini CLI 请求且请求体包含 thoughtSignature 时触发

## 修改内容

### backend/internal/handler/gemini_v1beta_handler.go
- 添加 `extractGeminiCLISessionHash` 函数提取 Gemini CLI 会话标识
- 添加 `isGeminiCLIRequest` 函数识别 Gemini CLI 请求
- 实现账号切换检测与 thoughtSignature 清理逻辑
- 添加 `geminiCLITmpDirRegex` 正则表达式

### backend/internal/service/gateway_service.go
- 添加 `GetCachedSessionAccountID` 方法查询粘性会话绑定的账号 ID

### backend/internal/service/gemini_native_signature_cleaner.go (新增)
- 实现 `CleanGeminiNativeThoughtSignatures` 函数
- 递归清理 JSON 中的所有 thoughtSignature 字段
- 支持任意 JSON 顶层类型(object/array)

### backend/internal/handler/gemini_cli_session_test.go (新增)
- 测试 Gemini CLI 会话哈希提取逻辑
- 测试 tmp 目录正则匹配
- 覆盖有/无 privileged-user-id 的场景

## 影响范围

- 修复 Gemini CLI 多轮对话时账号切换导致的 400 错误
- 提高粘性会话的稳定性和容错能力
- 不影响其他客户端(Claude Code 等)的会话标识生成

## 测试

- 单元测试:go test -tags=unit ./internal/handler -run TestExtractGeminiCLISessionHash
- 单元测试:go test -tags=unit ./internal/handler -run TestGeminiCLITmpDirRegex
- 编译验证:go build ./cmd/server
2026-01-26 04:40:38 +08:00
ianshaw
8c1233393f fix(antigravity): 修复 Gemini 模型 thoughtSignature 被错误覆盖的问题
## 问题描述

在使用 Gemini 模型(gemini-3-flash-preview)时,出现 400 错误:
"Unable to submit request because Thought signature is not valid"

## 根本原因

在 `request_transformer.go` 的 `buildParts()` 函数中:
- 对于 `tool_use` 和 `thinking` 块,当 `allowDummyThought=true`(Gemini 模型)时
- 代码会无条件将客户端传入的真实 `thoughtSignature` 覆盖成 dummy 值
- 导致 Gemini API 验证签名失败(签名与上下文不匹配)

## 修复方案

修改 signature 处理逻辑:
1. **优先透传真实 signature**:如果客户端提供了有效的 signature,保留它
2. **缺失时才使用 dummy**:只有在 signature 缺失且是 Gemini 模型时,才使用 dummy signature
3. **Claude 模型特殊处理**:将 dummy signature 视为缺失,避免透传到需要真实签名的链路

## 修改内容

### request_transformer.go
- `thinking` 块(第 367-386 行):优先透传真实 signature
- `tool_use` 块(第 411-418 行):优先透传真实 signature

### request_transformer_test.go
- 修改测试用例名称,反映新的行为
- 新增测试用例验证"缺失时才使用 dummy"的逻辑

## 影响范围

- 修复 Gemini 模型在多轮对话中使用 tool_use 时的签名验证错误
- 不影响 Claude 模型的现有行为
- 提高跨账号切换时的稳定性

相关问题:#[issue_number]
2026-01-26 03:06:45 +08:00
Gemini Wen
9cdb0568cc fix(subscription): 修复订阅调整逻辑,已过期订阅从当前时间计算
- 已过期订阅延长时,从当前时间开始增加天数
- 已过期订阅不允许负向调整(缩短)
- 未过期订阅保持原逻辑,从原过期时间增减

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-25 18:12:15 +08:00
shaw
74e05b83ea fix(ratelimit): 修复 OpenAI 账号限流倒计时计算错误
- 解析 x-codex-* 响应头获取正确的重置时间
- 7d 限制用尽时使用 codex_7d_reset_after_seconds
- 提取 Normalize() 方法统一窗口规范化逻辑
2026-01-25 13:32:08 +08:00
Ubuntu
4ded9e7d49 fix(oauth): 为初始 OAuth 授权添加 LoadCodeAssist 重试机制
问题:
- 初始授权时 LoadCodeAssist 没有重试机制,失败后直接跳过
- 导致账号创建时就可能缺失 project_id
- 之后每次刷新都因为 missing_project_id 报错

修复:
- 统一使用 loadProjectIDWithRetry 方法(最多4次尝试)
- 初始授权和token刷新使用相同的重试策略
- 保留原注释说明部分账户可能没有 project_id

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-24 23:41:36 +08:00
Ubuntu
716272a1e2 fix(oauth): 彻底修复 project_id 丢失问题
根本原因:
- BuildAccountCredentials 只在 project_id 非空时才添加该字段
- LoadCodeAssist 失败时返回空字符串 → 新 credentials 不包含 project_id 键
- 普通合并逻辑只保留新 credentials 中不存在的键,无法覆盖空值

解决方案:
1. 在合并后特殊处理 project_id:如果新值为空但旧值非空,保留旧值
2. LoadCodeAssist 失败不再返回错误,只记录警告
3. Token 刷新成功(access_token 已更新)就不应标记账户为 error

改进效果:
- 即使 LoadCodeAssist 连续失败,已有的 project_id 也不会丢失
- 避免因临时网络问题将账户误标记为不可用
- 允许在下次刷新时自动重试获取 project_id

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-24 23:04:48 +08:00
shaw
9cc8352593 feat(auth): 密码重置邮件队列化与限流优化
- 邮件发送改为异步队列处理,避免并发导致发送失败
- 新增 Email 维度限流(30秒冷却期),防止邮件轰炸
- Token 验证使用常量时间比较,防止时序攻击
- 重构代码消除冗余,提取公共验证逻辑
2026-01-24 22:55:28 +08:00
shaw
43a1031e38 fix(test): 修复订阅相关测试失败问题
1. 使用未来日期(2099年)作为测试订阅的过期时间,避免
   normalizeSubscriptionStatus 将测试数据标记为过期
2. 修复 List 方法调用参数不足的问题(新增 sortBy/sortOrder 参数)
2026-01-24 21:10:02 +08:00
Wesley Liddick
a5547b2f30 Merge pull request #380 from DDZS987/fix/oauth-token-refresh-missing-project-id-retry
fix(oauth): 修复 OAuth 令牌刷新时 missing_project_id 误报问题
2026-01-24 20:29:43 +08:00
shaw
b0aa23540b feat(subscription): 订阅过期状态自动更新与服务端排序
- 新增 SubscriptionExpiryService 定时任务,每分钟更新过期订阅状态
- 订阅列表支持服务端排序(按过期时间、状态、创建时间)
- 实时显示正确的过期状态,无需等待定时任务
- 允许对已过期订阅进行续期操作
- DataTable 组件支持 serverSideSort 模式
2026-01-24 20:26:01 +08:00
Ubuntu
ffaa6c4a17 fix(oauth): 修复 OAuth 令牌刷新时 missing_project_id 误报问题
问题描述:
在网络波动环境下,LoadCodeAssist 临时失败会被错误地标记为
"missing_project_id" 不可重试错误,导致账户被禁用。但实际上
账户配置正确,手动刷新后即可恢复。

解决方案:
1. 为 LoadCodeAssist 增加重试机制(最多4次,指数退避)
2. 区分真正的配置缺失和临时网络故障:
   - 如果之前有 project_id,本次获取失败只保留旧值,不报错
   - 只有从未获取过 project_id 且本次也失败时,才标记为缺失
3. 优化错误判断逻辑,避免误报

改进效果:
- 提高在复杂网络环境(如家宽代理)下的鲁棒性
- 减少因临时网络波动导致的服务中断
- 保持真正配置错误的检测能力

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-24 17:44:56 +08:00
Wesley Liddick
fbf72f0ec4 Merge pull request #377 from lynoot/fix/non-streaming-chunk-aggregation
fix(gateway): aggregate all text chunks in non-streaming Gemini responses
2026-01-24 08:49:10 +08:00
lynoot
909b8a8f9c fix(gateway): aggregate all text chunks in non-streaming Gemini responses
Previously, collectGeminiSSE() only returned the last chunk received
from the upstream streaming response when converting to non-streaming.
This caused incomplete responses where only the final text fragment
was returned to clients.

For example, a request asking to "count from 1 to 10" would only
return "\n" (the last chunk) instead of "1\n2\n3\n...\n10\n".

This was especially problematic for JSON structured output where
the opening brace "{" from the first chunk was lost, resulting
in invalid JSON like: colors": ["red", "blue"]}

The fix:
- Collect all text parts from each SSE chunk into a slice
- Merge all collected text parts into the final response
- Reuse the same pattern as handleGeminiStreamToNonStreaming
  in antigravity_gateway_service.go

Fixes: non-streaming responses returning incomplete text
Fixes: structured output (JSON schema) returning invalid JSON
2026-01-23 13:54:09 +00:00
shaw
4a0fe3b143 feat(gateway): 增加 SUGGESTION MODE 请求拦截
扩展现有的预热请求拦截功能,新增对 SUGGESTION MODE 请求的拦截:
- 检测 messages 最后一条 user 消息是否以 [SUGGESTION MODE: 开头
- 拦截后返回空内容响应,节省 token 消耗
- 重构检测逻辑,合并为单一函数,只解析一次 JSON
2026-01-23 16:57:25 +08:00
shaw
a1292fac81 feat(oauth): 支持Anthropic的Team账号使用sk授权 2026-01-23 16:30:12 +08:00
shaw
7f98be4f91 fix(oauth): 更新 Anthropic 账号 OAuth 参数,同步最新客户端 2026-01-23 16:00:42 +08:00
shaw
fd73b8875d feat(frontend): 优化账号限流状态显示,直接展示倒计时 2026-01-23 15:48:25 +08:00
shaw
f9ab1daa3c feat: 保存并显示OAuth账号邮箱地址 2026-01-23 15:17:47 +08:00
shaw
d27b847442 fix(test): 修复测试中引用不存在的函数名
测试文件引用了 IsTokenVersionStale 函数,但实际函数名为 CheckTokenVersion,导致 CI 构建失败
2026-01-23 10:58:30 +08:00
shaw
dac6bc2228 fix(token-cache): 版本过时时使用最新token而非旧token
上次修复(2665230)只阻止了写入缓存,但仍返回旧token导致403。
现在版本过时时直接使用DB中的最新token返回。

- 重构 IsTokenVersionStale 为 CheckTokenVersion,返回最新account
- 消除重复DB查询,复用版本检查时已获取的account
2026-01-23 10:29:52 +08:00
Wesley Liddick
4bd3dbf2ce Merge pull request #354 from DuckyProject/fix/frontend-table
feat(frontend): 账号表格默认排序/持久化 + 自动刷新 + 更多菜单外部关闭
2026-01-22 21:17:48 +08:00
Wesley Liddick
226df1c23a Merge pull request #358 from 0xff26b9a8/main
fix(antigravity): 修复非流式 Claude To Antigravity 响应内容为空的问题
2026-01-22 21:16:54 +08:00
shaw
2665230a09 fix(token-cache): 修复异步刷新与请求线程的缓存竞态条件
- 新增 _token_version 版本号机制,防止过期 token 污染缓存
- TokenRefreshService 刷新成功后写入版本号并清除缓存
- TokenProvider 写入缓存前检查版本,过时则跳过
- ClearError 时同步清除 token 缓存
2026-01-22 21:09:28 +08:00
0xff26b9a8
4f0c2b794c style: gofmt antigravity_gateway_service.go 2026-01-22 14:38:55 +08:00
0xff26b9a8
e756064c19 fix(antigravity): 修复非流式 Claude To Antigravity 响应内容为空的问题
- 修复 TransformGeminiToClaude 的 JSON 解析逻辑,当 V1InternalResponse
  解析成功但 candidates 为空时,尝试直接解析为 GeminiResponse 格式
- 修复 handleClaudeStreamToNonStreaming 收集流式响应的逻辑,累积所有
  chunks 的内容而不是只保留最后一个(最后一个 chunk 通常 text 为空)
- 新增 mergeCollectedPartsToResponse 函数,合并所有类型的 parts
  (text、thinking、functionCall、inlineData),保持原始顺序
- 连续的普通 text parts 合并为一个,thinking/functionCall/inlineData 保持原样
2026-01-22 14:17:59 +08:00
Wesley Liddick
17dfb0af01 Merge pull request #346 from 0xff26b9a8/main
refactor(antigravity): 提取并同步 Schema 清理逻辑至 schema_cleaner.go
2026-01-22 08:46:11 +08:00
ducky
ff74f517df feat(frontend): 账号表格默认排序/持久化 + 自动刷新 + 更多菜单外部关闭 2026-01-21 22:43:25 +08:00
0xff26b9a8
477a9a180f fix: 修复 schema 清理逻辑 2026-01-21 10:58:39 +08:00
0xff26b9a8
da48df06d2 refactor(antigravity): 提取并同步 Schema 清理逻辑至 schema_cleaner.go
主要变更:
1. 重构代码结构:
   - 将 CleanJSONSchema 及其相关辅助函数从 request_transformer.go 提取到独立的 schema_cleaner.go 文件中,实现逻辑解耦。

2. 逻辑优化与修正:
   - 参考 Antigravity-Manager (json_schema.rs) 的实现逻辑,修正了 Schema 清洗策略。
2026-01-20 23:41:53 +08:00
Wesley Liddick
39fad63ccf Merge pull request #345 from whoismonay/main
mod(frontend): 管理员订阅/兑换码分组选择展示备注
2026-01-20 16:22:53 +08:00
Wesley Liddick
5602d02b1b Merge pull request #343 from mt21625457/main
fix(调度): 完善粘性会话清理与账号调度刷新 和 启用 OpenAI OAuth HTTP/2 并修复清理任务 lint
2026-01-20 16:05:53 +08:00
shaw
81989eed1c test: add promo_code_enabled to API contract test 2026-01-20 16:02:49 +08:00
shaw
192efb84a0 feat(promo-code): complete promo code feature implementation
- Add promo_code_enabled field to SystemSettings and PublicSettings DTOs
- Add promo code validation in registration flow
- Add admin settings UI for promo code configuration
- Add i18n translations for promo code feature
2026-01-20 15:56:26 +08:00
shaw
8672347f93 fix(settings): add missing promo_code_enabled field in public settings API
The field was defined in DTO but not mapped in handler response.
2026-01-20 15:49:57 +08:00
yangjianbo
5e5d4a513b feat: 移动镜像脚本位置 2026-01-20 15:11:27 +08:00
墨颜
88b6358472 build(frontend): vite 加载开发环境变量
- 使用 loadEnv 读取 VITE_DEV_PROXY_TARGET/VITE_DEV_PORT
- 注入 public settings 与 dev proxy 使用同源后端地址
2026-01-20 15:04:18 +08:00
墨颜
dd8d5e2c42 mod(frontend): 订阅分组下拉显示备注
- 订阅管理/分配订阅:下拉项展示分组备注
- 兑换码/订阅类型:下拉项展示分组备注
- 复用 GroupOptionItem/GroupBadge 保持一致体验
2026-01-20 15:02:48 +08:00
shaw
d91e2328fb test(tlsfingerprint): add multi-profile fingerprint verification test
Add TestAllProfiles to verify TLS fingerprint configurations from
config.yaml against tls.peet.ws. Tests check JA4 cipher hash (stable
part) to validate fingerprint spoofing works correctly.
2026-01-20 14:53:15 +08:00
yangjianbo
2a16735495 fix(测试): 修复 SelectAccountWithLoadAwareness 调用缺少参数
为 gateway_multiplatform_test.go 中的 SelectAccountWithLoadAwareness
调用添加缺少的第6个参数 metadataUserID,修复 CI 测试编译错误。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 14:16:46 +08:00
yangjianbo
292f25f9ca Merge branch 'main' of https://github.com/mt21625457/aicodex2api 2026-01-20 14:02:08 +08:00
yangjianbo
c92e37775a Merge branch 'dev' 2026-01-20 13:57:08 +08:00
yangjianbo
f6ed3d1456 Merge branch 'test' into dev 2026-01-20 11:59:13 +08:00
yangjianbo
84686753e8 Merge branch 'test' of https://github.com/mt21625457/aicodex2api into test 2026-01-20 11:51:44 +08:00
yangjianbo
91f01309da fix(调度): 完善粘性会话清理与账号调度刷新
- Update/BulkUpdate 按不可调度字段触发缓存刷新
- GatewayCache 支持多前缀会话键清理
- 模型路由与混合调度优化粘性会话处理
- 补充调度与缓存相关测试覆盖
2026-01-20 11:40:55 +08:00
yangjianbo
57a1fc9d33 style(仓储): 格式化账号仓储
- gofmt 修正 lint 格式提示
2026-01-20 11:30:36 +08:00
shaw
c95a864975 docs: add TLS fingerprint tool link 2026-01-20 11:30:10 +08:00
yangjianbo
7a83db6180 fix(调度): 完善粘性会话清理与账号调度刷新
- Update/BulkUpdate 按不可调度字段触发缓存刷新
- GatewayCache 支持多前缀会话键清理
- 模型路由与混合调度优化粘性会话处理
- 补充调度与缓存相关测试覆盖
2026-01-20 11:19:32 +08:00
Wesley Liddick
a8513da7ff Merge pull request #335 from geminiwen/main
feat(subscription): 支持调整订阅时长(延长/缩短)
2026-01-20 08:52:19 +08:00
Gemini Wen
53534d3956 style(admin): 统一列设置按钮位置到刷新按钮右侧
将订阅管理和账号管理页面的列设置按钮移动到刷新按钮右侧,
与用户管理页面保持一致的布局。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:18:51 +08:00
Gemini Wen
cc07a0e295 feat(subscription): 支持调整订阅时长(延长/缩短)
- 将"延长订阅"功能改为"调整订阅",支持正数延长、负数缩短
- 后端验证:调整天数范围 -36500 到 36500,缩短后剩余天数必须 > 0
- 前端同步更新界面文案和验证逻辑

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:11:30 +08:00
Wesley Liddick
e7bc62500b Merge pull request #333 from whoismonay/main
fix: 普通用户接口移除管理员敏感字段透传
2026-01-19 21:35:51 +08:00
墨颜
c8fb9ef3a5 style(dto): 修复 gofmt 格式问题
- 修复 mappers.go 中 Notes 字段的对齐格式
- 修复 types.go 中 BulkAssignResult 结构体字段的 JSON tag 对齐

修复 golangci-lint 检查中的 gofmt 格式错误
2026-01-19 21:26:30 +08:00
Wesley Liddick
eb5e6214bc Merge pull request #332 from geminiwen/main
fix: 修复手动刷新令牌后缓存未清除导致403错误的问题
2026-01-19 20:53:06 +08:00
shaw
568d6ee10e fix: 修复测试缺少新增设置字段 2026-01-19 20:52:05 +08:00
墨颜
6aef1af76e fix(redeem): 用户兑换历史不返回备注
- 用户侧 RedeemCode DTO 移除 notes 字段,避免泄露内部备注\n- 新增 AdminRedeemCode,并调整管理员兑换码接口继续返回 notes\n- 增加 /api/v1/redeem/history 契约测试,确保用户侧响应不包含 notes
2026-01-19 20:09:35 +08:00
Gemini Wen
a54852e129 fix: 补充API契约测试中缺失的hide_ccs_import_button字段
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 20:06:36 +08:00
Gemini Wen
668118def1 fix: 修复遗漏的测试文件更新和lint错误
- 更新api_contract_test.go以匹配NewAccountHandler新增的tokenCacheInvalidator参数
- 修复errcheck lint错误,显式忽略c.Error()返回值

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 19:58:09 +08:00
yangjianbo
73e6b160f8 feat(认证): 启用 OpenAI OAuth HTTP/2 并修复清理任务 lint
为共享 req 客户端增加 HTTP/2 选项与缓存隔离
OpenAI OAuth 超时提升到 120s,并按协议控制强制
新增客户端池与 OAuth 客户端单测覆盖
修复 usage cleanup 相关 errcheck/ineffassign/staticcheck 并统一格式

测试: make test
2026-01-19 19:50:57 +08:00
Gemini Wen
6fec141de6 fix: 修复手动刷新令牌后缓存未清除导致403错误的问题
手动刷新令牌后,新token保存到数据库但Redis缓存未清除,
导致下游请求仍然使用旧的失效token,上游API返回403错误。

修复方案:在AccountHandler中注入TokenCacheInvalidator,
刷新令牌成功后调用InvalidateToken清除缓存。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 19:40:43 +08:00
墨颜
31cde6c555 fix(subscriptions): 用户订阅不返回分配信息
- 用户侧 UserSubscription DTO 移除 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段\n- 新增 AdminUserSubscription,并调整管理员订阅接口与批量分配结果使用\n- 增加 /api/v1/subscriptions 契约测试,确保用户侧响应不包含上述字段
2026-01-19 19:35:13 +08:00
shaw
b1a980f344 feat: 添加隐藏CCS导入按钮的设置选项
在管理后台设置页面新增开关,允许管理员隐藏API Keys页面的"导入CCS"按钮
2026-01-19 19:25:16 +08:00
墨颜
00d9fbd220 fix(user): 普通用户接口不返回备注
- 用户侧 dto.User 移除 notes 字段,避免泄露管理员备注\n- 新增 dto.AdminUser 并调整 /admin/users 系列接口使用\n- 前端拆分 User/AdminUser,管理端用户页面使用 AdminUser\n- 更新契约测试:/api/v1/auth/me 响应不包含 notes
2026-01-19 19:23:51 +08:00
墨颜
4f4c9679bf fix(groups): 用户分组不下发内部路由信息
- 普通用户 Group DTO 移除 model_routing/account_count/account_groups,避免泄露内部路由与账号信息\n- 新增 AdminGroup DTO,并仅在管理员分组接口返回完整字段\n- 前端拆分 Group/AdminGroup,管理端页面与 API 使用 AdminGroup\n- 增加 /api/v1/groups/available 契约测试,防止回归
2026-01-19 18:58:42 +08:00
shaw
3dab71729d feat: usage接口支持TLS指纹和缓存User-Agent 2026-01-19 17:06:16 +08:00
墨颜
2f6f758670 fix(usage): 用户使用记录不下发账号计费倍率
- 将 usage log DTO 拆分为用户/管理员两类
- 用户接口不返回 account_rate_multiplier/ip_address/account
- 管理员接口保留管理员字段
- 补充契约测试防止回归
2026-01-19 17:05:42 +08:00
shaw
090c8981dd fix: 更新Claude OAuth授权配置以匹配最新规范
- 更新TokenURL和RedirectURI为platform.claude.com
- 更新scope定义,区分浏览器URL和内部API调用
- 修正state/code_verifier生成算法使用base64url编码
- 修正授权URL参数顺序并添加code=true
- 更新token交换请求头匹配官方实现
- 清理未使用的类型和函数
2026-01-19 16:40:06 +08:00
shaw
fbb572948d fix: 修复会话数量查询使用错误的超时配置 2026-01-19 11:45:04 +08:00
249 changed files with 27529 additions and 8881 deletions

View File

@@ -19,7 +19,7 @@ jobs:
cache: true
- name: Verify Go version
run: |
go version | grep -q 'go1.25.5'
go version | grep -q 'go1.25.6'
- name: Unit tests
working-directory: backend
run: make test-unit
@@ -38,7 +38,7 @@ jobs:
cache: true
- name: Verify Go version
run: |
go version | grep -q 'go1.25.5'
go version | grep -q 'go1.25.6'
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:

View File

@@ -115,7 +115,7 @@ jobs:
- name: Verify Go version
run: |
go version | grep -q 'go1.25.5'
go version | grep -q 'go1.25.6'
# Docker setup for GoReleaser
- name: Set up QEMU
@@ -222,8 +222,9 @@ jobs:
REPO="${{ github.repository }}"
GHCR_IMAGE="ghcr.io/${REPO,,}" # ${,,} converts to lowercase
# 获取 tag message 内容
# 获取 tag message 内容并转义 Markdown 特殊字符
TAG_MESSAGE='${{ steps.tag_message.outputs.message }}'
TAG_MESSAGE=$(echo "$TAG_MESSAGE" | sed 's/\([_*`\[]\)/\\\1/g')
# 限制消息长度Telegram 消息限制 4096 字符,预留空间给头尾固定内容)
if [ ${#TAG_MESSAGE} -gt 3500 ]; then

View File

@@ -22,7 +22,7 @@ jobs:
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.25.5'
go version | grep -q 'go1.25.6'
- name: Run govulncheck
working-directory: backend
run: |

View File

@@ -7,7 +7,7 @@
# =============================================================================
ARG NODE_IMAGE=node:24-alpine
ARG GOLANG_IMAGE=golang:1.25.5-alpine
ARG GOLANG_IMAGE=golang:1.25.6-alpine
ARG ALPINE_IMAGE=alpine:3.20
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn

130
README.md
View File

@@ -18,7 +18,7 @@ English | [中文](README_CN.md)
## Demo
Try Sub2API online: **https://v2.pincc.ai/**
Try Sub2API online: **https://demo.sub2api.org/**
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
@@ -128,7 +128,7 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
---
### Method 2: Docker Compose
### Method 2: Docker Compose (Recommended)
Deploy with Docker Compose, including PostgreSQL and Redis containers.
@@ -137,87 +137,157 @@ Deploy with Docker Compose, including PostgreSQL and Redis containers.
- Docker 20.10+
- Docker Compose v2+
#### Installation Steps
#### Quick Start (One-Click Deployment)
Use the automated deployment script for easy setup:
```bash
# Create deployment directory
mkdir -p sub2api-deploy && cd sub2api-deploy
# Download and run deployment preparation script
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
# Start services
docker-compose -f docker-compose.local.yml up -d
# View logs
docker-compose -f docker-compose.local.yml logs -f sub2api
```
**What the script does:**
- Downloads `docker-compose.local.yml` and `.env.example`
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
- Creates `.env` file with auto-generated secrets
- Creates data directories (uses local directories for easy backup/migration)
- Displays generated credentials for your reference
#### Manual Deployment
If you prefer manual setup:
```bash
# 1. Clone the repository
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
cd sub2api/deploy
# 2. Enter the deploy directory
cd deploy
# 3. Copy environment configuration
# 2. Copy environment configuration
cp .env.example .env
# 4. Edit configuration (set your passwords)
# 3. Edit configuration (generate secure passwords)
nano .env
```
**Required configuration in `.env`:**
```bash
# PostgreSQL password (REQUIRED - change this!)
# PostgreSQL password (REQUIRED)
POSTGRES_PASSWORD=your_secure_password_here
# JWT Secret (RECOMMENDED - keeps users logged in after restart)
JWT_SECRET=your_jwt_secret_here
# TOTP Encryption Key (RECOMMENDED - preserves 2FA after restart)
TOTP_ENCRYPTION_KEY=your_totp_key_here
# Optional: Admin account
ADMIN_EMAIL=admin@example.com
ADMIN_PASSWORD=your_admin_password
# Optional: Custom port
SERVER_PORT=8080
```
# Optional: Security configuration
# Enable URL allowlist validation (false to skip allowlist checks, only basic format validation)
SECURITY_URL_ALLOWLIST_ENABLED=false
**Generate secure secrets:**
```bash
# Generate JWT_SECRET
openssl rand -hex 32
# Allow insecure HTTP URLs when allowlist is disabled (default: false, requires https)
# ⚠️ WARNING: Enabling this allows HTTP (plaintext) URLs which can expose API keys
# Only recommended for:
# - Development/testing environments
# - Internal networks with trusted endpoints
# - When using local test servers (http://localhost)
# PRODUCTION: Keep this false or use HTTPS URLs only
SECURITY_URL_ALLOWLIST_ALLOW_INSECURE_HTTP=false
# Generate TOTP_ENCRYPTION_KEY
openssl rand -hex 32
# Allow private IP addresses for upstream/pricing/CRS (for internal deployments)
SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=false
# Generate POSTGRES_PASSWORD
openssl rand -hex 32
```
```bash
# 4. Create data directories (for local version)
mkdir -p data postgres_data redis_data
# 5. Start all services
# Option A: Local directory version (recommended - easy migration)
docker-compose -f docker-compose.local.yml up -d
# Option B: Named volumes version (simple setup)
docker-compose up -d
# 6. Check status
docker-compose ps
docker-compose -f docker-compose.local.yml ps
# 7. View logs
docker-compose logs -f sub2api
docker-compose -f docker-compose.local.yml logs -f sub2api
```
#### Deployment Versions
| Version | Data Storage | Migration | Best For |
|---------|-------------|-----------|----------|
| **docker-compose.local.yml** | Local directories | ✅ Easy (tar entire directory) | Production, frequent backups |
| **docker-compose.yml** | Named volumes | ⚠️ Requires docker commands | Simple setup |
**Recommendation:** Use `docker-compose.local.yml` (deployed by script) for easier data management.
#### Access
Open `http://YOUR_SERVER_IP:8080` in your browser.
If admin password was auto-generated, find it in logs:
```bash
docker-compose -f docker-compose.local.yml logs sub2api | grep "admin password"
```
#### Upgrade
```bash
# Pull latest image and recreate container
docker-compose pull
docker-compose up -d
docker-compose -f docker-compose.local.yml pull
docker-compose -f docker-compose.local.yml up -d
```
#### Easy Migration (Local Directory Version)
When using `docker-compose.local.yml`, migrate to a new server easily:
```bash
# On source server
docker-compose -f docker-compose.local.yml down
cd ..
tar czf sub2api-complete.tar.gz sub2api-deploy/
# Transfer to new server
scp sub2api-complete.tar.gz user@new-server:/path/
# On new server
tar xzf sub2api-complete.tar.gz
cd sub2api-deploy/
docker-compose -f docker-compose.local.yml up -d
```
#### Useful Commands
```bash
# Stop all services
docker-compose down
docker-compose -f docker-compose.local.yml down
# Restart
docker-compose restart
docker-compose -f docker-compose.local.yml restart
# View all logs
docker-compose logs -f
docker-compose -f docker-compose.local.yml logs -f
# Remove all data (caution!)
docker-compose -f docker-compose.local.yml down
rm -rf data/ postgres_data/ redis_data/
```
---

View File

@@ -135,7 +135,7 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
---
### 方式二Docker Compose
### 方式二Docker Compose(推荐)
使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
@@ -144,87 +144,157 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
- Docker 20.10+
- Docker Compose v2+
#### 安装步骤
#### 快速开始(一键部署)
使用自动化部署脚本快速搭建:
```bash
# 创建部署目录
mkdir -p sub2api-deploy && cd sub2api-deploy
# 下载并运行部署准备脚本
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
# 启动服务
docker-compose -f docker-compose.local.yml up -d
# 查看日志
docker-compose -f docker-compose.local.yml logs -f sub2api
```
**脚本功能:**
- 下载 `docker-compose.local.yml``.env.example`
- 自动生成安全凭证JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD
- 创建 `.env` 文件并填充自动生成的密钥
- 创建数据目录(使用本地目录,便于备份和迁移)
- 显示生成的凭证供你记录
#### 手动部署
如果你希望手动配置:
```bash
# 1. 克隆仓库
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
cd sub2api/deploy
# 2. 进入 deploy 目录
cd deploy
# 3. 复制环境配置文件
# 2. 复制环境配置文件
cp .env.example .env
# 4. 编辑配置(设置密码
# 3. 编辑配置(生成安全密码)
nano .env
```
**`.env` 必须配置项:**
```bash
# PostgreSQL 密码(必须修改!
# PostgreSQL 密码(必
POSTGRES_PASSWORD=your_secure_password_here
# JWT 密钥(推荐 - 重启后保持用户登录状态)
JWT_SECRET=your_jwt_secret_here
# TOTP 加密密钥(推荐 - 重启后保留双因素认证)
TOTP_ENCRYPTION_KEY=your_totp_key_here
# 可选:管理员账号
ADMIN_EMAIL=admin@example.com
ADMIN_PASSWORD=your_admin_password
# 可选:自定义端口
SERVER_PORT=8080
```
# 可选:安全配置
# 启用 URL 白名单验证false 则跳过白名单检查,仅做基本格式校验)
SECURITY_URL_ALLOWLIST_ENABLED=false
**生成安全密钥:**
```bash
# 生成 JWT_SECRET
openssl rand -hex 32
# 关闭白名单时,是否允许 http:// URL默认 false只允许 https://
# ⚠️ 警告:允许 HTTP 会暴露 API 密钥(明文传输)
# 仅建议在以下场景使用:
# - 开发/测试环境
# - 内部可信网络
# - 本地测试服务器http://localhost
# 生产环境:保持 false 或仅使用 HTTPS URL
SECURITY_URL_ALLOWLIST_ALLOW_INSECURE_HTTP=false
# 生成 TOTP_ENCRYPTION_KEY
openssl rand -hex 32
# 是否允许私有 IP 地址用于上游/定价/CRS内网部署时使用
SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=false
# 生成 POSTGRES_PASSWORD
openssl rand -hex 32
```
```bash
# 4. 创建数据目录(本地版)
mkdir -p data postgres_data redis_data
# 5. 启动所有服务
# 选项 A本地目录版推荐 - 易于迁移)
docker-compose -f docker-compose.local.yml up -d
# 选项 B命名卷版简单设置
docker-compose up -d
# 6. 查看状态
docker-compose ps
docker-compose -f docker-compose.local.yml ps
# 7. 查看日志
docker-compose logs -f sub2api
docker-compose -f docker-compose.local.yml logs -f sub2api
```
#### 部署版本对比
| 版本 | 数据存储 | 迁移便利性 | 适用场景 |
|------|---------|-----------|---------|
| **docker-compose.local.yml** | 本地目录 | ✅ 简单(打包整个目录) | 生产环境、频繁备份 |
| **docker-compose.yml** | 命名卷 | ⚠️ 需要 docker 命令 | 简单设置 |
**推荐:** 使用 `docker-compose.local.yml`(脚本部署)以便更轻松地管理数据。
#### 访问
在浏览器中打开 `http://你的服务器IP:8080`
如果管理员密码是自动生成的,在日志中查找:
```bash
docker-compose -f docker-compose.local.yml logs sub2api | grep "admin password"
```
#### 升级
```bash
# 拉取最新镜像并重建容器
docker-compose pull
docker-compose up -d
docker-compose -f docker-compose.local.yml pull
docker-compose -f docker-compose.local.yml up -d
```
#### 轻松迁移(本地目录版)
使用 `docker-compose.local.yml` 时,可以轻松迁移到新服务器:
```bash
# 源服务器
docker-compose -f docker-compose.local.yml down
cd ..
tar czf sub2api-complete.tar.gz sub2api-deploy/
# 传输到新服务器
scp sub2api-complete.tar.gz user@new-server:/path/
# 新服务器
tar xzf sub2api-complete.tar.gz
cd sub2api-deploy/
docker-compose -f docker-compose.local.yml up -d
```
#### 常用命令
```bash
# 停止所有服务
docker-compose down
docker-compose -f docker-compose.local.yml down
# 重启
docker-compose restart
docker-compose -f docker-compose.local.yml restart
# 查看所有日志
docker-compose logs -f
docker-compose -f docker-compose.local.yml logs -f
# 删除所有数据(谨慎!)
docker-compose -f docker-compose.local.yml down
rm -rf data/ postgres_data/ redis_data/
```
---

View File

@@ -1 +1 @@
0.1.46
0.1.61

View File

@@ -70,6 +70,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
@@ -138,6 +139,10 @@ func provideCleanup(
accountExpiry.Stop()
return nil
}},
{"SubscriptionExpiryService", func() error {
subscriptionExpiry.Stop()
return nil
}},
{"PricingService", func() error {
pricing.Stop()
return nil

View File

@@ -63,7 +63,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil {
return nil, err
}
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService)
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
@@ -75,6 +81,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
announcementRepository := repository.NewAnnouncementRepository(client)
announcementReadRepository := repository.NewAnnouncementReadRepository(client)
announcementService := service.NewAnnouncementService(announcementRepository, announcementReadRepository, userRepository, userSubscriptionRepository)
announcementHandler := handler.NewAnnouncementHandler(announcementService)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
@@ -84,7 +94,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
}
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
accountRepository := repository.NewAccountRepository(client, db)
schedulerCache := repository.NewSchedulerCache(redisClient)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
@@ -105,21 +116,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
httpUpstream := repository.NewHTTPUpstream(configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
identityCache := repository.NewIdentityCache(redisClient)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -128,7 +141,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
schedulerCache := repository.NewSchedulerCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
@@ -137,7 +149,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityCache := repository.NewIdentityCache(redisClient)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
@@ -161,11 +172,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
totpHandler := handler.NewTotpHandler(totpService)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -176,9 +188,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -211,6 +224,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
@@ -278,6 +292,10 @@ func provideCleanup(
accountExpiry.Stop()
return nil
}},
{"SubscriptionExpiryService", func() error {
subscriptionExpiry.Stop()
return nil
}},
{"PricingService", func() error {
pricing.Stop()
return nil

249
backend/ent/announcement.go Normal file
View File

@@ -0,0 +1,249 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
// Announcement is the model entity for the Announcement schema.
type Announcement struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// 公告标题
Title string `json:"title,omitempty"`
// 公告内容(支持 Markdown
Content string `json:"content,omitempty"`
// 状态: draft, active, archived
Status string `json:"status,omitempty"`
// 展示条件JSON 规则)
Targeting domain.AnnouncementTargeting `json:"targeting,omitempty"`
// 开始展示时间(为空表示立即生效)
StartsAt *time.Time `json:"starts_at,omitempty"`
// 结束展示时间(为空表示永久生效)
EndsAt *time.Time `json:"ends_at,omitempty"`
// 创建人用户ID管理员
CreatedBy *int64 `json:"created_by,omitempty"`
// 更新人用户ID管理员
UpdatedBy *int64 `json:"updated_by,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the AnnouncementQuery when eager-loading is set.
Edges AnnouncementEdges `json:"edges"`
selectValues sql.SelectValues
}
// AnnouncementEdges holds the relations/edges for other nodes in the graph.
type AnnouncementEdges struct {
// Reads holds the value of the reads edge.
Reads []*AnnouncementRead `json:"reads,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [1]bool
}
// ReadsOrErr returns the Reads value or an error if the edge
// was not loaded in eager-loading.
func (e AnnouncementEdges) ReadsOrErr() ([]*AnnouncementRead, error) {
if e.loadedTypes[0] {
return e.Reads, nil
}
return nil, &NotLoadedError{edge: "reads"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*Announcement) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case announcement.FieldTargeting:
values[i] = new([]byte)
case announcement.FieldID, announcement.FieldCreatedBy, announcement.FieldUpdatedBy:
values[i] = new(sql.NullInt64)
case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus:
values[i] = new(sql.NullString)
case announcement.FieldStartsAt, announcement.FieldEndsAt, announcement.FieldCreatedAt, announcement.FieldUpdatedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the Announcement fields.
func (_m *Announcement) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case announcement.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case announcement.FieldTitle:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field title", values[i])
} else if value.Valid {
_m.Title = value.String
}
case announcement.FieldContent:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field content", values[i])
} else if value.Valid {
_m.Content = value.String
}
case announcement.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
} else if value.Valid {
_m.Status = value.String
}
case announcement.FieldTargeting:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field targeting", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.Targeting); err != nil {
return fmt.Errorf("unmarshal field targeting: %w", err)
}
}
case announcement.FieldStartsAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field starts_at", values[i])
} else if value.Valid {
_m.StartsAt = new(time.Time)
*_m.StartsAt = value.Time
}
case announcement.FieldEndsAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field ends_at", values[i])
} else if value.Valid {
_m.EndsAt = new(time.Time)
*_m.EndsAt = value.Time
}
case announcement.FieldCreatedBy:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field created_by", values[i])
} else if value.Valid {
_m.CreatedBy = new(int64)
*_m.CreatedBy = value.Int64
}
case announcement.FieldUpdatedBy:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field updated_by", values[i])
} else if value.Valid {
_m.UpdatedBy = new(int64)
*_m.UpdatedBy = value.Int64
}
case announcement.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case announcement.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the Announcement.
// This includes values selected through modifiers, order, etc.
func (_m *Announcement) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryReads queries the "reads" edge of the Announcement entity.
func (_m *Announcement) QueryReads() *AnnouncementReadQuery {
return NewAnnouncementClient(_m.config).QueryReads(_m)
}
// Update returns a builder for updating this Announcement.
// Note that you need to call Announcement.Unwrap() before calling this method if this Announcement
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *Announcement) Update() *AnnouncementUpdateOne {
return NewAnnouncementClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the Announcement entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *Announcement) Unwrap() *Announcement {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: Announcement is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *Announcement) String() string {
var builder strings.Builder
builder.WriteString("Announcement(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("title=")
builder.WriteString(_m.Title)
builder.WriteString(", ")
builder.WriteString("content=")
builder.WriteString(_m.Content)
builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
builder.WriteString("targeting=")
builder.WriteString(fmt.Sprintf("%v", _m.Targeting))
builder.WriteString(", ")
if v := _m.StartsAt; v != nil {
builder.WriteString("starts_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.EndsAt; v != nil {
builder.WriteString("ends_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.CreatedBy; v != nil {
builder.WriteString("created_by=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.UpdatedBy; v != nil {
builder.WriteString("updated_by=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteByte(')')
return builder.String()
}
// Announcements is a parsable slice of Announcement.
type Announcements []*Announcement

View File

@@ -0,0 +1,164 @@
// Code generated by ent, DO NOT EDIT.
package announcement
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the announcement type in the database.
Label = "announcement"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldTitle holds the string denoting the title field in the database.
FieldTitle = "title"
// FieldContent holds the string denoting the content field in the database.
FieldContent = "content"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldTargeting holds the string denoting the targeting field in the database.
FieldTargeting = "targeting"
// FieldStartsAt holds the string denoting the starts_at field in the database.
FieldStartsAt = "starts_at"
// FieldEndsAt holds the string denoting the ends_at field in the database.
FieldEndsAt = "ends_at"
// FieldCreatedBy holds the string denoting the created_by field in the database.
FieldCreatedBy = "created_by"
// FieldUpdatedBy holds the string denoting the updated_by field in the database.
FieldUpdatedBy = "updated_by"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// EdgeReads holds the string denoting the reads edge name in mutations.
EdgeReads = "reads"
// Table holds the table name of the announcement in the database.
Table = "announcements"
// ReadsTable is the table that holds the reads relation/edge.
ReadsTable = "announcement_reads"
// ReadsInverseTable is the table name for the AnnouncementRead entity.
// It exists in this package in order to avoid circular dependency with the "announcementread" package.
ReadsInverseTable = "announcement_reads"
// ReadsColumn is the table column denoting the reads relation/edge.
ReadsColumn = "announcement_id"
)
// Columns holds all SQL columns for announcement fields.
var Columns = []string{
FieldID,
FieldTitle,
FieldContent,
FieldStatus,
FieldTargeting,
FieldStartsAt,
FieldEndsAt,
FieldCreatedBy,
FieldUpdatedBy,
FieldCreatedAt,
FieldUpdatedAt,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// TitleValidator is a validator for the "title" field. It is called by the builders before save.
TitleValidator func(string) error
// ContentValidator is a validator for the "content" field. It is called by the builders before save.
ContentValidator func(string) error
// DefaultStatus holds the default value on creation for the "status" field.
DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
StatusValidator func(string) error
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
)
// OrderOption defines the ordering options for the Announcement queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByTitle orders the results by the title field.
func ByTitle(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTitle, opts...).ToFunc()
}
// ByContent orders the results by the content field.
func ByContent(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldContent, opts...).ToFunc()
}
// ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
}
// ByStartsAt orders the results by the starts_at field.
func ByStartsAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStartsAt, opts...).ToFunc()
}
// ByEndsAt orders the results by the ends_at field.
func ByEndsAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEndsAt, opts...).ToFunc()
}
// ByCreatedBy orders the results by the created_by field.
func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedBy, opts...).ToFunc()
}
// ByUpdatedBy orders the results by the updated_by field.
func ByUpdatedBy(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedBy, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByReadsCount orders the results by reads count.
func ByReadsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newReadsStep(), opts...)
}
}
// ByReads orders the results by reads terms.
func ByReads(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newReadsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newReadsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(ReadsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, ReadsTable, ReadsColumn),
)
}

View File

@@ -0,0 +1,624 @@
// Code generated by ent, DO NOT EDIT.
package announcement
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldID, id))
}
// Title applies equality check predicate on the "title" field. It's identical to TitleEQ.
func Title(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldTitle, v))
}
// Content applies equality check predicate on the "content" field. It's identical to ContentEQ.
func Content(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldContent, v))
}
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldStatus, v))
}
// StartsAt applies equality check predicate on the "starts_at" field. It's identical to StartsAtEQ.
func StartsAt(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v))
}
// EndsAt applies equality check predicate on the "ends_at" field. It's identical to EndsAtEQ.
func EndsAt(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldEndsAt, v))
}
// CreatedBy applies equality check predicate on the "created_by" field. It's identical to CreatedByEQ.
func CreatedBy(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldCreatedBy, v))
}
// UpdatedBy applies equality check predicate on the "updated_by" field. It's identical to UpdatedByEQ.
func UpdatedBy(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldUpdatedBy, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldUpdatedAt, v))
}
// TitleEQ applies the EQ predicate on the "title" field.
func TitleEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldTitle, v))
}
// TitleNEQ applies the NEQ predicate on the "title" field.
func TitleNEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldTitle, v))
}
// TitleIn applies the In predicate on the "title" field.
func TitleIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldTitle, vs...))
}
// TitleNotIn applies the NotIn predicate on the "title" field.
func TitleNotIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldTitle, vs...))
}
// TitleGT applies the GT predicate on the "title" field.
func TitleGT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldTitle, v))
}
// TitleGTE applies the GTE predicate on the "title" field.
func TitleGTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldTitle, v))
}
// TitleLT applies the LT predicate on the "title" field.
func TitleLT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldTitle, v))
}
// TitleLTE applies the LTE predicate on the "title" field.
func TitleLTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldTitle, v))
}
// TitleContains applies the Contains predicate on the "title" field.
func TitleContains(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContains(FieldTitle, v))
}
// TitleHasPrefix applies the HasPrefix predicate on the "title" field.
func TitleHasPrefix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasPrefix(FieldTitle, v))
}
// TitleHasSuffix applies the HasSuffix predicate on the "title" field.
func TitleHasSuffix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasSuffix(FieldTitle, v))
}
// TitleEqualFold applies the EqualFold predicate on the "title" field.
func TitleEqualFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEqualFold(FieldTitle, v))
}
// TitleContainsFold applies the ContainsFold predicate on the "title" field.
func TitleContainsFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContainsFold(FieldTitle, v))
}
// ContentEQ applies the EQ predicate on the "content" field.
func ContentEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldContent, v))
}
// ContentNEQ applies the NEQ predicate on the "content" field.
func ContentNEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldContent, v))
}
// ContentIn applies the In predicate on the "content" field.
func ContentIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldContent, vs...))
}
// ContentNotIn applies the NotIn predicate on the "content" field.
func ContentNotIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldContent, vs...))
}
// ContentGT applies the GT predicate on the "content" field.
func ContentGT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldContent, v))
}
// ContentGTE applies the GTE predicate on the "content" field.
func ContentGTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldContent, v))
}
// ContentLT applies the LT predicate on the "content" field.
func ContentLT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldContent, v))
}
// ContentLTE applies the LTE predicate on the "content" field.
func ContentLTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldContent, v))
}
// ContentContains applies the Contains predicate on the "content" field.
func ContentContains(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContains(FieldContent, v))
}
// ContentHasPrefix applies the HasPrefix predicate on the "content" field.
func ContentHasPrefix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasPrefix(FieldContent, v))
}
// ContentHasSuffix applies the HasSuffix predicate on the "content" field.
func ContentHasSuffix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasSuffix(FieldContent, v))
}
// ContentEqualFold applies the EqualFold predicate on the "content" field.
func ContentEqualFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEqualFold(FieldContent, v))
}
// ContentContainsFold applies the ContainsFold predicate on the "content" field.
func ContentContainsFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContainsFold(FieldContent, v))
}
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldStatus, v))
}
// StatusNEQ applies the NEQ predicate on the "status" field.
func StatusNEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldStatus, v))
}
// StatusIn applies the In predicate on the "status" field.
func StatusIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldStatus, vs...))
}
// StatusNotIn applies the NotIn predicate on the "status" field.
func StatusNotIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldStatus, vs...))
}
// StatusGT applies the GT predicate on the "status" field.
func StatusGT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldStatus, v))
}
// StatusGTE applies the GTE predicate on the "status" field.
func StatusGTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldStatus, v))
}
// StatusLT applies the LT predicate on the "status" field.
func StatusLT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldStatus, v))
}
// StatusLTE applies the LTE predicate on the "status" field.
func StatusLTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldStatus, v))
}
// StatusContains applies the Contains predicate on the "status" field.
func StatusContains(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContains(FieldStatus, v))
}
// StatusHasPrefix applies the HasPrefix predicate on the "status" field.
func StatusHasPrefix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasPrefix(FieldStatus, v))
}
// StatusHasSuffix applies the HasSuffix predicate on the "status" field.
func StatusHasSuffix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasSuffix(FieldStatus, v))
}
// StatusEqualFold applies the EqualFold predicate on the "status" field.
func StatusEqualFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEqualFold(FieldStatus, v))
}
// StatusContainsFold applies the ContainsFold predicate on the "status" field.
func StatusContainsFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContainsFold(FieldStatus, v))
}
// TargetingIsNil applies the IsNil predicate on the "targeting" field.
func TargetingIsNil() predicate.Announcement {
return predicate.Announcement(sql.FieldIsNull(FieldTargeting))
}
// TargetingNotNil applies the NotNil predicate on the "targeting" field.
func TargetingNotNil() predicate.Announcement {
return predicate.Announcement(sql.FieldNotNull(FieldTargeting))
}
// StartsAtEQ applies the EQ predicate on the "starts_at" field.
func StartsAtEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v))
}
// StartsAtNEQ applies the NEQ predicate on the "starts_at" field.
func StartsAtNEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldStartsAt, v))
}
// StartsAtIn applies the In predicate on the "starts_at" field.
func StartsAtIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldStartsAt, vs...))
}
// StartsAtNotIn applies the NotIn predicate on the "starts_at" field.
func StartsAtNotIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldStartsAt, vs...))
}
// StartsAtGT applies the GT predicate on the "starts_at" field.
func StartsAtGT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldStartsAt, v))
}
// StartsAtGTE applies the GTE predicate on the "starts_at" field.
func StartsAtGTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldStartsAt, v))
}
// StartsAtLT applies the LT predicate on the "starts_at" field.
func StartsAtLT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldStartsAt, v))
}
// StartsAtLTE applies the LTE predicate on the "starts_at" field.
func StartsAtLTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldStartsAt, v))
}
// StartsAtIsNil applies the IsNil predicate on the "starts_at" field.
func StartsAtIsNil() predicate.Announcement {
return predicate.Announcement(sql.FieldIsNull(FieldStartsAt))
}
// StartsAtNotNil applies the NotNil predicate on the "starts_at" field.
func StartsAtNotNil() predicate.Announcement {
return predicate.Announcement(sql.FieldNotNull(FieldStartsAt))
}
// EndsAtEQ applies the EQ predicate on the "ends_at" field.
func EndsAtEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldEndsAt, v))
}
// EndsAtNEQ applies the NEQ predicate on the "ends_at" field.
func EndsAtNEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldEndsAt, v))
}
// EndsAtIn applies the In predicate on the "ends_at" field.
func EndsAtIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldEndsAt, vs...))
}
// EndsAtNotIn applies the NotIn predicate on the "ends_at" field.
func EndsAtNotIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldEndsAt, vs...))
}
// EndsAtGT applies the GT predicate on the "ends_at" field.
func EndsAtGT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldEndsAt, v))
}
// EndsAtGTE applies the GTE predicate on the "ends_at" field.
func EndsAtGTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldEndsAt, v))
}
// EndsAtLT applies the LT predicate on the "ends_at" field.
func EndsAtLT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldEndsAt, v))
}
// EndsAtLTE applies the LTE predicate on the "ends_at" field.
func EndsAtLTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldEndsAt, v))
}
// EndsAtIsNil applies the IsNil predicate on the "ends_at" field.
func EndsAtIsNil() predicate.Announcement {
return predicate.Announcement(sql.FieldIsNull(FieldEndsAt))
}
// EndsAtNotNil applies the NotNil predicate on the "ends_at" field.
func EndsAtNotNil() predicate.Announcement {
return predicate.Announcement(sql.FieldNotNull(FieldEndsAt))
}
// CreatedByEQ applies the EQ predicate on the "created_by" field.
func CreatedByEQ(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldCreatedBy, v))
}
// CreatedByNEQ applies the NEQ predicate on the "created_by" field.
func CreatedByNEQ(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldCreatedBy, v))
}
// CreatedByIn applies the In predicate on the "created_by" field.
func CreatedByIn(vs ...int64) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldCreatedBy, vs...))
}
// CreatedByNotIn applies the NotIn predicate on the "created_by" field.
func CreatedByNotIn(vs ...int64) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldCreatedBy, vs...))
}
// CreatedByGT applies the GT predicate on the "created_by" field.
func CreatedByGT(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldCreatedBy, v))
}
// CreatedByGTE applies the GTE predicate on the "created_by" field.
func CreatedByGTE(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldCreatedBy, v))
}
// CreatedByLT applies the LT predicate on the "created_by" field.
func CreatedByLT(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldCreatedBy, v))
}
// CreatedByLTE applies the LTE predicate on the "created_by" field.
func CreatedByLTE(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldCreatedBy, v))
}
// CreatedByIsNil applies the IsNil predicate on the "created_by" field.
func CreatedByIsNil() predicate.Announcement {
return predicate.Announcement(sql.FieldIsNull(FieldCreatedBy))
}
// CreatedByNotNil applies the NotNil predicate on the "created_by" field.
func CreatedByNotNil() predicate.Announcement {
return predicate.Announcement(sql.FieldNotNull(FieldCreatedBy))
}
// UpdatedByEQ applies the EQ predicate on the "updated_by" field.
func UpdatedByEQ(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldUpdatedBy, v))
}
// UpdatedByNEQ applies the NEQ predicate on the "updated_by" field.
func UpdatedByNEQ(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldUpdatedBy, v))
}
// UpdatedByIn applies the In predicate on the "updated_by" field.
func UpdatedByIn(vs ...int64) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldUpdatedBy, vs...))
}
// UpdatedByNotIn applies the NotIn predicate on the "updated_by" field.
func UpdatedByNotIn(vs ...int64) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldUpdatedBy, vs...))
}
// UpdatedByGT applies the GT predicate on the "updated_by" field.
func UpdatedByGT(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldUpdatedBy, v))
}
// UpdatedByGTE applies the GTE predicate on the "updated_by" field.
func UpdatedByGTE(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldUpdatedBy, v))
}
// UpdatedByLT applies the LT predicate on the "updated_by" field.
func UpdatedByLT(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldUpdatedBy, v))
}
// UpdatedByLTE applies the LTE predicate on the "updated_by" field.
func UpdatedByLTE(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldUpdatedBy, v))
}
// UpdatedByIsNil applies the IsNil predicate on the "updated_by" field.
func UpdatedByIsNil() predicate.Announcement {
return predicate.Announcement(sql.FieldIsNull(FieldUpdatedBy))
}
// UpdatedByNotNil applies the NotNil predicate on the "updated_by" field.
func UpdatedByNotNil() predicate.Announcement {
return predicate.Announcement(sql.FieldNotNull(FieldUpdatedBy))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldUpdatedAt, v))
}
// HasReads applies the HasEdge predicate on the "reads" edge.
func HasReads() predicate.Announcement {
return predicate.Announcement(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, ReadsTable, ReadsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasReadsWith applies the HasEdge predicate on the "reads" edge with a given conditions (other predicates).
func HasReadsWith(preds ...predicate.AnnouncementRead) predicate.Announcement {
return predicate.Announcement(func(s *sql.Selector) {
step := newReadsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.Announcement) predicate.Announcement {
return predicate.Announcement(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.Announcement) predicate.Announcement {
return predicate.Announcement(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.Announcement) predicate.Announcement {
return predicate.Announcement(sql.NotPredicates(p))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// AnnouncementDelete is the builder for deleting a Announcement entity.
type AnnouncementDelete struct {
config
hooks []Hook
mutation *AnnouncementMutation
}
// Where appends a list predicates to the AnnouncementDelete builder.
func (_d *AnnouncementDelete) Where(ps ...predicate.Announcement) *AnnouncementDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *AnnouncementDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AnnouncementDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *AnnouncementDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(announcement.Table, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// AnnouncementDeleteOne is the builder for deleting a single Announcement entity.
type AnnouncementDeleteOne struct {
_d *AnnouncementDelete
}
// Where appends a list predicates to the AnnouncementDelete builder.
func (_d *AnnouncementDeleteOne) Where(ps ...predicate.Announcement) *AnnouncementDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *AnnouncementDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{announcement.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AnnouncementDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,643 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"database/sql/driver"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// AnnouncementQuery is the builder for querying Announcement entities.
type AnnouncementQuery struct {
config
ctx *QueryContext
order []announcement.OrderOption
inters []Interceptor
predicates []predicate.Announcement
withReads *AnnouncementReadQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the AnnouncementQuery builder.
func (_q *AnnouncementQuery) Where(ps ...predicate.Announcement) *AnnouncementQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *AnnouncementQuery) Limit(limit int) *AnnouncementQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *AnnouncementQuery) Offset(offset int) *AnnouncementQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *AnnouncementQuery) Unique(unique bool) *AnnouncementQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *AnnouncementQuery) Order(o ...announcement.OrderOption) *AnnouncementQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryReads chains the current query on the "reads" edge.
func (_q *AnnouncementQuery) QueryReads() *AnnouncementReadQuery {
query := (&AnnouncementReadClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(announcement.Table, announcement.FieldID, selector),
sqlgraph.To(announcementread.Table, announcementread.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, announcement.ReadsTable, announcement.ReadsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first Announcement entity from the query.
// Returns a *NotFoundError when no Announcement was found.
func (_q *AnnouncementQuery) First(ctx context.Context) (*Announcement, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{announcement.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *AnnouncementQuery) FirstX(ctx context.Context) *Announcement {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first Announcement ID from the query.
// Returns a *NotFoundError when no Announcement ID was found.
func (_q *AnnouncementQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{announcement.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *AnnouncementQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single Announcement entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one Announcement entity is found.
// Returns a *NotFoundError when no Announcement entities are found.
func (_q *AnnouncementQuery) Only(ctx context.Context) (*Announcement, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{announcement.Label}
default:
return nil, &NotSingularError{announcement.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *AnnouncementQuery) OnlyX(ctx context.Context) *Announcement {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only Announcement ID in the query.
// Returns a *NotSingularError when more than one Announcement ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *AnnouncementQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{announcement.Label}
default:
err = &NotSingularError{announcement.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *AnnouncementQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of Announcements.
func (_q *AnnouncementQuery) All(ctx context.Context) ([]*Announcement, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*Announcement, *AnnouncementQuery]()
return withInterceptors[[]*Announcement](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *AnnouncementQuery) AllX(ctx context.Context) []*Announcement {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of Announcement IDs.
func (_q *AnnouncementQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(announcement.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *AnnouncementQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *AnnouncementQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*AnnouncementQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *AnnouncementQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *AnnouncementQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *AnnouncementQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the AnnouncementQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *AnnouncementQuery) Clone() *AnnouncementQuery {
if _q == nil {
return nil
}
return &AnnouncementQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]announcement.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.Announcement{}, _q.predicates...),
withReads: _q.withReads.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithReads tells the query-builder to eager-load the nodes that are connected to
// the "reads" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AnnouncementQuery) WithReads(opts ...func(*AnnouncementReadQuery)) *AnnouncementQuery {
query := (&AnnouncementReadClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withReads = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// Title string `json:"title,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.Announcement.Query().
// GroupBy(announcement.FieldTitle).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *AnnouncementQuery) GroupBy(field string, fields ...string) *AnnouncementGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &AnnouncementGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = announcement.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// Title string `json:"title,omitempty"`
// }
//
// client.Announcement.Query().
// Select(announcement.FieldTitle).
// Scan(ctx, &v)
func (_q *AnnouncementQuery) Select(fields ...string) *AnnouncementSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &AnnouncementSelect{AnnouncementQuery: _q}
sbuild.label = announcement.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a AnnouncementSelect configured with the given aggregations.
func (_q *AnnouncementQuery) Aggregate(fns ...AggregateFunc) *AnnouncementSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *AnnouncementQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !announcement.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *AnnouncementQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Announcement, error) {
var (
nodes = []*Announcement{}
_spec = _q.querySpec()
loadedTypes = [1]bool{
_q.withReads != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*Announcement).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &Announcement{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withReads; query != nil {
if err := _q.loadReads(ctx, query, nodes,
func(n *Announcement) { n.Edges.Reads = []*AnnouncementRead{} },
func(n *Announcement, e *AnnouncementRead) { n.Edges.Reads = append(n.Edges.Reads, e) }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *AnnouncementQuery) loadReads(ctx context.Context, query *AnnouncementReadQuery, nodes []*Announcement, init func(*Announcement), assign func(*Announcement, *AnnouncementRead)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*Announcement)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(announcementread.FieldAnnouncementID)
}
query.Where(predicate.AnnouncementRead(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(announcement.ReadsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.AnnouncementID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "announcement_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *AnnouncementQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *AnnouncementQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(announcement.Table, announcement.Columns, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, announcement.FieldID)
for i := range fields {
if fields[i] != announcement.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *AnnouncementQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(announcement.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = announcement.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *AnnouncementQuery) ForUpdate(opts ...sql.LockOption) *AnnouncementQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *AnnouncementQuery) ForShare(opts ...sql.LockOption) *AnnouncementQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// AnnouncementGroupBy is the group-by builder for Announcement entities.
type AnnouncementGroupBy struct {
selector
build *AnnouncementQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *AnnouncementGroupBy) Aggregate(fns ...AggregateFunc) *AnnouncementGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *AnnouncementGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AnnouncementQuery, *AnnouncementGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *AnnouncementGroupBy) sqlScan(ctx context.Context, root *AnnouncementQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// AnnouncementSelect is the builder for selecting fields of Announcement entities.
type AnnouncementSelect struct {
*AnnouncementQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *AnnouncementSelect) Aggregate(fns ...AggregateFunc) *AnnouncementSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *AnnouncementSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AnnouncementQuery, *AnnouncementSelect](ctx, _s.AnnouncementQuery, _s, _s.inters, v)
}
func (_s *AnnouncementSelect) sqlScan(ctx context.Context, root *AnnouncementQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,824 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
// AnnouncementUpdate is the builder for updating Announcement entities.
type AnnouncementUpdate struct {
config
hooks []Hook
mutation *AnnouncementMutation
}
// Where appends a list predicates to the AnnouncementUpdate builder.
func (_u *AnnouncementUpdate) Where(ps ...predicate.Announcement) *AnnouncementUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetTitle sets the "title" field.
func (_u *AnnouncementUpdate) SetTitle(v string) *AnnouncementUpdate {
_u.mutation.SetTitle(v)
return _u
}
// SetNillableTitle sets the "title" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableTitle(v *string) *AnnouncementUpdate {
if v != nil {
_u.SetTitle(*v)
}
return _u
}
// SetContent sets the "content" field.
func (_u *AnnouncementUpdate) SetContent(v string) *AnnouncementUpdate {
_u.mutation.SetContent(v)
return _u
}
// SetNillableContent sets the "content" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableContent(v *string) *AnnouncementUpdate {
if v != nil {
_u.SetContent(*v)
}
return _u
}
// SetStatus sets the "status" field.
func (_u *AnnouncementUpdate) SetStatus(v string) *AnnouncementUpdate {
_u.mutation.SetStatus(v)
return _u
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableStatus(v *string) *AnnouncementUpdate {
if v != nil {
_u.SetStatus(*v)
}
return _u
}
// SetTargeting sets the "targeting" field.
func (_u *AnnouncementUpdate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdate {
_u.mutation.SetTargeting(v)
return _u
}
// SetNillableTargeting sets the "targeting" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableTargeting(v *domain.AnnouncementTargeting) *AnnouncementUpdate {
if v != nil {
_u.SetTargeting(*v)
}
return _u
}
// ClearTargeting clears the value of the "targeting" field.
func (_u *AnnouncementUpdate) ClearTargeting() *AnnouncementUpdate {
_u.mutation.ClearTargeting()
return _u
}
// SetStartsAt sets the "starts_at" field.
func (_u *AnnouncementUpdate) SetStartsAt(v time.Time) *AnnouncementUpdate {
_u.mutation.SetStartsAt(v)
return _u
}
// SetNillableStartsAt sets the "starts_at" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableStartsAt(v *time.Time) *AnnouncementUpdate {
if v != nil {
_u.SetStartsAt(*v)
}
return _u
}
// ClearStartsAt clears the value of the "starts_at" field.
func (_u *AnnouncementUpdate) ClearStartsAt() *AnnouncementUpdate {
_u.mutation.ClearStartsAt()
return _u
}
// SetEndsAt sets the "ends_at" field.
func (_u *AnnouncementUpdate) SetEndsAt(v time.Time) *AnnouncementUpdate {
_u.mutation.SetEndsAt(v)
return _u
}
// SetNillableEndsAt sets the "ends_at" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableEndsAt(v *time.Time) *AnnouncementUpdate {
if v != nil {
_u.SetEndsAt(*v)
}
return _u
}
// ClearEndsAt clears the value of the "ends_at" field.
func (_u *AnnouncementUpdate) ClearEndsAt() *AnnouncementUpdate {
_u.mutation.ClearEndsAt()
return _u
}
// SetCreatedBy sets the "created_by" field.
func (_u *AnnouncementUpdate) SetCreatedBy(v int64) *AnnouncementUpdate {
_u.mutation.ResetCreatedBy()
_u.mutation.SetCreatedBy(v)
return _u
}
// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableCreatedBy(v *int64) *AnnouncementUpdate {
if v != nil {
_u.SetCreatedBy(*v)
}
return _u
}
// AddCreatedBy adds value to the "created_by" field.
func (_u *AnnouncementUpdate) AddCreatedBy(v int64) *AnnouncementUpdate {
_u.mutation.AddCreatedBy(v)
return _u
}
// ClearCreatedBy clears the value of the "created_by" field.
func (_u *AnnouncementUpdate) ClearCreatedBy() *AnnouncementUpdate {
_u.mutation.ClearCreatedBy()
return _u
}
// SetUpdatedBy sets the "updated_by" field.
func (_u *AnnouncementUpdate) SetUpdatedBy(v int64) *AnnouncementUpdate {
_u.mutation.ResetUpdatedBy()
_u.mutation.SetUpdatedBy(v)
return _u
}
// SetNillableUpdatedBy sets the "updated_by" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableUpdatedBy(v *int64) *AnnouncementUpdate {
if v != nil {
_u.SetUpdatedBy(*v)
}
return _u
}
// AddUpdatedBy adds value to the "updated_by" field.
func (_u *AnnouncementUpdate) AddUpdatedBy(v int64) *AnnouncementUpdate {
_u.mutation.AddUpdatedBy(v)
return _u
}
// ClearUpdatedBy clears the value of the "updated_by" field.
func (_u *AnnouncementUpdate) ClearUpdatedBy() *AnnouncementUpdate {
_u.mutation.ClearUpdatedBy()
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *AnnouncementUpdate) SetUpdatedAt(v time.Time) *AnnouncementUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// AddReadIDs adds the "reads" edge to the AnnouncementRead entity by IDs.
func (_u *AnnouncementUpdate) AddReadIDs(ids ...int64) *AnnouncementUpdate {
_u.mutation.AddReadIDs(ids...)
return _u
}
// AddReads adds the "reads" edges to the AnnouncementRead entity.
func (_u *AnnouncementUpdate) AddReads(v ...*AnnouncementRead) *AnnouncementUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddReadIDs(ids...)
}
// Mutation returns the AnnouncementMutation object of the builder.
func (_u *AnnouncementUpdate) Mutation() *AnnouncementMutation {
return _u.mutation
}
// ClearReads clears all "reads" edges to the AnnouncementRead entity.
func (_u *AnnouncementUpdate) ClearReads() *AnnouncementUpdate {
_u.mutation.ClearReads()
return _u
}
// RemoveReadIDs removes the "reads" edge to AnnouncementRead entities by IDs.
func (_u *AnnouncementUpdate) RemoveReadIDs(ids ...int64) *AnnouncementUpdate {
_u.mutation.RemoveReadIDs(ids...)
return _u
}
// RemoveReads removes "reads" edges to AnnouncementRead entities.
func (_u *AnnouncementUpdate) RemoveReads(v ...*AnnouncementRead) *AnnouncementUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveReadIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *AnnouncementUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AnnouncementUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *AnnouncementUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AnnouncementUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *AnnouncementUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := announcement.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *AnnouncementUpdate) check() error {
if v, ok := _u.mutation.Title(); ok {
if err := announcement.TitleValidator(v); err != nil {
return &ValidationError{Name: "title", err: fmt.Errorf(`ent: validator failed for field "Announcement.title": %w`, err)}
}
}
if v, ok := _u.mutation.Content(); ok {
if err := announcement.ContentValidator(v); err != nil {
return &ValidationError{Name: "content", err: fmt.Errorf(`ent: validator failed for field "Announcement.content": %w`, err)}
}
}
if v, ok := _u.mutation.Status(); ok {
if err := announcement.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
}
}
return nil
}
func (_u *AnnouncementUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(announcement.Table, announcement.Columns, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.Title(); ok {
_spec.SetField(announcement.FieldTitle, field.TypeString, value)
}
if value, ok := _u.mutation.Content(); ok {
_spec.SetField(announcement.FieldContent, field.TypeString, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.Targeting(); ok {
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
}
if _u.mutation.TargetingCleared() {
_spec.ClearField(announcement.FieldTargeting, field.TypeJSON)
}
if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(announcement.FieldStartsAt, field.TypeTime, value)
}
if _u.mutation.StartsAtCleared() {
_spec.ClearField(announcement.FieldStartsAt, field.TypeTime)
}
if value, ok := _u.mutation.EndsAt(); ok {
_spec.SetField(announcement.FieldEndsAt, field.TypeTime, value)
}
if _u.mutation.EndsAtCleared() {
_spec.ClearField(announcement.FieldEndsAt, field.TypeTime)
}
if value, ok := _u.mutation.CreatedBy(); ok {
_spec.SetField(announcement.FieldCreatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedCreatedBy(); ok {
_spec.AddField(announcement.FieldCreatedBy, field.TypeInt64, value)
}
if _u.mutation.CreatedByCleared() {
_spec.ClearField(announcement.FieldCreatedBy, field.TypeInt64)
}
if value, ok := _u.mutation.UpdatedBy(); ok {
_spec.SetField(announcement.FieldUpdatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedUpdatedBy(); ok {
_spec.AddField(announcement.FieldUpdatedBy, field.TypeInt64, value)
}
if _u.mutation.UpdatedByCleared() {
_spec.ClearField(announcement.FieldUpdatedBy, field.TypeInt64)
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(announcement.FieldUpdatedAt, field.TypeTime, value)
}
if _u.mutation.ReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: announcement.ReadsTable,
Columns: []string{announcement.ReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedReadsIDs(); len(nodes) > 0 && !_u.mutation.ReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: announcement.ReadsTable,
Columns: []string{announcement.ReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.ReadsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: announcement.ReadsTable,
Columns: []string{announcement.ReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{announcement.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// AnnouncementUpdateOne is the builder for updating a single Announcement entity.
type AnnouncementUpdateOne struct {
config
fields []string
hooks []Hook
mutation *AnnouncementMutation
}
// SetTitle sets the "title" field.
func (_u *AnnouncementUpdateOne) SetTitle(v string) *AnnouncementUpdateOne {
_u.mutation.SetTitle(v)
return _u
}
// SetNillableTitle sets the "title" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableTitle(v *string) *AnnouncementUpdateOne {
if v != nil {
_u.SetTitle(*v)
}
return _u
}
// SetContent sets the "content" field.
func (_u *AnnouncementUpdateOne) SetContent(v string) *AnnouncementUpdateOne {
_u.mutation.SetContent(v)
return _u
}
// SetNillableContent sets the "content" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableContent(v *string) *AnnouncementUpdateOne {
if v != nil {
_u.SetContent(*v)
}
return _u
}
// SetStatus sets the "status" field.
func (_u *AnnouncementUpdateOne) SetStatus(v string) *AnnouncementUpdateOne {
_u.mutation.SetStatus(v)
return _u
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableStatus(v *string) *AnnouncementUpdateOne {
if v != nil {
_u.SetStatus(*v)
}
return _u
}
// SetTargeting sets the "targeting" field.
func (_u *AnnouncementUpdateOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdateOne {
_u.mutation.SetTargeting(v)
return _u
}
// SetNillableTargeting sets the "targeting" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableTargeting(v *domain.AnnouncementTargeting) *AnnouncementUpdateOne {
if v != nil {
_u.SetTargeting(*v)
}
return _u
}
// ClearTargeting clears the value of the "targeting" field.
func (_u *AnnouncementUpdateOne) ClearTargeting() *AnnouncementUpdateOne {
_u.mutation.ClearTargeting()
return _u
}
// SetStartsAt sets the "starts_at" field.
func (_u *AnnouncementUpdateOne) SetStartsAt(v time.Time) *AnnouncementUpdateOne {
_u.mutation.SetStartsAt(v)
return _u
}
// SetNillableStartsAt sets the "starts_at" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableStartsAt(v *time.Time) *AnnouncementUpdateOne {
if v != nil {
_u.SetStartsAt(*v)
}
return _u
}
// ClearStartsAt clears the value of the "starts_at" field.
func (_u *AnnouncementUpdateOne) ClearStartsAt() *AnnouncementUpdateOne {
_u.mutation.ClearStartsAt()
return _u
}
// SetEndsAt sets the "ends_at" field.
func (_u *AnnouncementUpdateOne) SetEndsAt(v time.Time) *AnnouncementUpdateOne {
_u.mutation.SetEndsAt(v)
return _u
}
// SetNillableEndsAt sets the "ends_at" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableEndsAt(v *time.Time) *AnnouncementUpdateOne {
if v != nil {
_u.SetEndsAt(*v)
}
return _u
}
// ClearEndsAt clears the value of the "ends_at" field.
func (_u *AnnouncementUpdateOne) ClearEndsAt() *AnnouncementUpdateOne {
_u.mutation.ClearEndsAt()
return _u
}
// SetCreatedBy sets the "created_by" field.
func (_u *AnnouncementUpdateOne) SetCreatedBy(v int64) *AnnouncementUpdateOne {
_u.mutation.ResetCreatedBy()
_u.mutation.SetCreatedBy(v)
return _u
}
// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableCreatedBy(v *int64) *AnnouncementUpdateOne {
if v != nil {
_u.SetCreatedBy(*v)
}
return _u
}
// AddCreatedBy adds value to the "created_by" field.
func (_u *AnnouncementUpdateOne) AddCreatedBy(v int64) *AnnouncementUpdateOne {
_u.mutation.AddCreatedBy(v)
return _u
}
// ClearCreatedBy clears the value of the "created_by" field.
func (_u *AnnouncementUpdateOne) ClearCreatedBy() *AnnouncementUpdateOne {
_u.mutation.ClearCreatedBy()
return _u
}
// SetUpdatedBy sets the "updated_by" field.
func (_u *AnnouncementUpdateOne) SetUpdatedBy(v int64) *AnnouncementUpdateOne {
_u.mutation.ResetUpdatedBy()
_u.mutation.SetUpdatedBy(v)
return _u
}
// SetNillableUpdatedBy sets the "updated_by" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableUpdatedBy(v *int64) *AnnouncementUpdateOne {
if v != nil {
_u.SetUpdatedBy(*v)
}
return _u
}
// AddUpdatedBy adds value to the "updated_by" field.
func (_u *AnnouncementUpdateOne) AddUpdatedBy(v int64) *AnnouncementUpdateOne {
_u.mutation.AddUpdatedBy(v)
return _u
}
// ClearUpdatedBy clears the value of the "updated_by" field.
func (_u *AnnouncementUpdateOne) ClearUpdatedBy() *AnnouncementUpdateOne {
_u.mutation.ClearUpdatedBy()
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *AnnouncementUpdateOne) SetUpdatedAt(v time.Time) *AnnouncementUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// AddReadIDs adds the "reads" edge to the AnnouncementRead entity by IDs.
func (_u *AnnouncementUpdateOne) AddReadIDs(ids ...int64) *AnnouncementUpdateOne {
_u.mutation.AddReadIDs(ids...)
return _u
}
// AddReads adds the "reads" edges to the AnnouncementRead entity.
func (_u *AnnouncementUpdateOne) AddReads(v ...*AnnouncementRead) *AnnouncementUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddReadIDs(ids...)
}
// Mutation returns the AnnouncementMutation object of the builder.
func (_u *AnnouncementUpdateOne) Mutation() *AnnouncementMutation {
return _u.mutation
}
// ClearReads clears all "reads" edges to the AnnouncementRead entity.
func (_u *AnnouncementUpdateOne) ClearReads() *AnnouncementUpdateOne {
_u.mutation.ClearReads()
return _u
}
// RemoveReadIDs removes the "reads" edge to AnnouncementRead entities by IDs.
func (_u *AnnouncementUpdateOne) RemoveReadIDs(ids ...int64) *AnnouncementUpdateOne {
_u.mutation.RemoveReadIDs(ids...)
return _u
}
// RemoveReads removes "reads" edges to AnnouncementRead entities.
func (_u *AnnouncementUpdateOne) RemoveReads(v ...*AnnouncementRead) *AnnouncementUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveReadIDs(ids...)
}
// Where appends a list predicates to the AnnouncementUpdate builder.
func (_u *AnnouncementUpdateOne) Where(ps ...predicate.Announcement) *AnnouncementUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *AnnouncementUpdateOne) Select(field string, fields ...string) *AnnouncementUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated Announcement entity.
func (_u *AnnouncementUpdateOne) Save(ctx context.Context) (*Announcement, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AnnouncementUpdateOne) SaveX(ctx context.Context) *Announcement {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *AnnouncementUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AnnouncementUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *AnnouncementUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := announcement.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *AnnouncementUpdateOne) check() error {
if v, ok := _u.mutation.Title(); ok {
if err := announcement.TitleValidator(v); err != nil {
return &ValidationError{Name: "title", err: fmt.Errorf(`ent: validator failed for field "Announcement.title": %w`, err)}
}
}
if v, ok := _u.mutation.Content(); ok {
if err := announcement.ContentValidator(v); err != nil {
return &ValidationError{Name: "content", err: fmt.Errorf(`ent: validator failed for field "Announcement.content": %w`, err)}
}
}
if v, ok := _u.mutation.Status(); ok {
if err := announcement.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
}
}
return nil
}
func (_u *AnnouncementUpdateOne) sqlSave(ctx context.Context) (_node *Announcement, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(announcement.Table, announcement.Columns, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Announcement.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, announcement.FieldID)
for _, f := range fields {
if !announcement.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != announcement.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.Title(); ok {
_spec.SetField(announcement.FieldTitle, field.TypeString, value)
}
if value, ok := _u.mutation.Content(); ok {
_spec.SetField(announcement.FieldContent, field.TypeString, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.Targeting(); ok {
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
}
if _u.mutation.TargetingCleared() {
_spec.ClearField(announcement.FieldTargeting, field.TypeJSON)
}
if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(announcement.FieldStartsAt, field.TypeTime, value)
}
if _u.mutation.StartsAtCleared() {
_spec.ClearField(announcement.FieldStartsAt, field.TypeTime)
}
if value, ok := _u.mutation.EndsAt(); ok {
_spec.SetField(announcement.FieldEndsAt, field.TypeTime, value)
}
if _u.mutation.EndsAtCleared() {
_spec.ClearField(announcement.FieldEndsAt, field.TypeTime)
}
if value, ok := _u.mutation.CreatedBy(); ok {
_spec.SetField(announcement.FieldCreatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedCreatedBy(); ok {
_spec.AddField(announcement.FieldCreatedBy, field.TypeInt64, value)
}
if _u.mutation.CreatedByCleared() {
_spec.ClearField(announcement.FieldCreatedBy, field.TypeInt64)
}
if value, ok := _u.mutation.UpdatedBy(); ok {
_spec.SetField(announcement.FieldUpdatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedUpdatedBy(); ok {
_spec.AddField(announcement.FieldUpdatedBy, field.TypeInt64, value)
}
if _u.mutation.UpdatedByCleared() {
_spec.ClearField(announcement.FieldUpdatedBy, field.TypeInt64)
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(announcement.FieldUpdatedAt, field.TypeTime, value)
}
if _u.mutation.ReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: announcement.ReadsTable,
Columns: []string{announcement.ReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedReadsIDs(); len(nodes) > 0 && !_u.mutation.ReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: announcement.ReadsTable,
Columns: []string{announcement.ReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.ReadsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: announcement.ReadsTable,
Columns: []string{announcement.ReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &Announcement{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{announcement.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -0,0 +1,185 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// AnnouncementRead is the model entity for the AnnouncementRead schema.
type AnnouncementRead struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// AnnouncementID holds the value of the "announcement_id" field.
AnnouncementID int64 `json:"announcement_id,omitempty"`
// UserID holds the value of the "user_id" field.
UserID int64 `json:"user_id,omitempty"`
// 用户首次已读时间
ReadAt time.Time `json:"read_at,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the AnnouncementReadQuery when eager-loading is set.
Edges AnnouncementReadEdges `json:"edges"`
selectValues sql.SelectValues
}
// AnnouncementReadEdges holds the relations/edges for other nodes in the graph.
type AnnouncementReadEdges struct {
// Announcement holds the value of the announcement edge.
Announcement *Announcement `json:"announcement,omitempty"`
// User holds the value of the user edge.
User *User `json:"user,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [2]bool
}
// AnnouncementOrErr returns the Announcement value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e AnnouncementReadEdges) AnnouncementOrErr() (*Announcement, error) {
if e.Announcement != nil {
return e.Announcement, nil
} else if e.loadedTypes[0] {
return nil, &NotFoundError{label: announcement.Label}
}
return nil, &NotLoadedError{edge: "announcement"}
}
// UserOrErr returns the User value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e AnnouncementReadEdges) UserOrErr() (*User, error) {
if e.User != nil {
return e.User, nil
} else if e.loadedTypes[1] {
return nil, &NotFoundError{label: user.Label}
}
return nil, &NotLoadedError{edge: "user"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*AnnouncementRead) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case announcementread.FieldID, announcementread.FieldAnnouncementID, announcementread.FieldUserID:
values[i] = new(sql.NullInt64)
case announcementread.FieldReadAt, announcementread.FieldCreatedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the AnnouncementRead fields.
func (_m *AnnouncementRead) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case announcementread.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case announcementread.FieldAnnouncementID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field announcement_id", values[i])
} else if value.Valid {
_m.AnnouncementID = value.Int64
}
case announcementread.FieldUserID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field user_id", values[i])
} else if value.Valid {
_m.UserID = value.Int64
}
case announcementread.FieldReadAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field read_at", values[i])
} else if value.Valid {
_m.ReadAt = value.Time
}
case announcementread.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the AnnouncementRead.
// This includes values selected through modifiers, order, etc.
func (_m *AnnouncementRead) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryAnnouncement queries the "announcement" edge of the AnnouncementRead entity.
func (_m *AnnouncementRead) QueryAnnouncement() *AnnouncementQuery {
return NewAnnouncementReadClient(_m.config).QueryAnnouncement(_m)
}
// QueryUser queries the "user" edge of the AnnouncementRead entity.
func (_m *AnnouncementRead) QueryUser() *UserQuery {
return NewAnnouncementReadClient(_m.config).QueryUser(_m)
}
// Update returns a builder for updating this AnnouncementRead.
// Note that you need to call AnnouncementRead.Unwrap() before calling this method if this AnnouncementRead
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *AnnouncementRead) Update() *AnnouncementReadUpdateOne {
return NewAnnouncementReadClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the AnnouncementRead entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *AnnouncementRead) Unwrap() *AnnouncementRead {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: AnnouncementRead is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *AnnouncementRead) String() string {
var builder strings.Builder
builder.WriteString("AnnouncementRead(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("announcement_id=")
builder.WriteString(fmt.Sprintf("%v", _m.AnnouncementID))
builder.WriteString(", ")
builder.WriteString("user_id=")
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
builder.WriteString(", ")
builder.WriteString("read_at=")
builder.WriteString(_m.ReadAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteByte(')')
return builder.String()
}
// AnnouncementReads is a parsable slice of AnnouncementRead.
type AnnouncementReads []*AnnouncementRead

View File

@@ -0,0 +1,127 @@
// Code generated by ent, DO NOT EDIT.
package announcementread
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the announcementread type in the database.
Label = "announcement_read"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldAnnouncementID holds the string denoting the announcement_id field in the database.
FieldAnnouncementID = "announcement_id"
// FieldUserID holds the string denoting the user_id field in the database.
FieldUserID = "user_id"
// FieldReadAt holds the string denoting the read_at field in the database.
FieldReadAt = "read_at"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// EdgeAnnouncement holds the string denoting the announcement edge name in mutations.
EdgeAnnouncement = "announcement"
// EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user"
// Table holds the table name of the announcementread in the database.
Table = "announcement_reads"
// AnnouncementTable is the table that holds the announcement relation/edge.
AnnouncementTable = "announcement_reads"
// AnnouncementInverseTable is the table name for the Announcement entity.
// It exists in this package in order to avoid circular dependency with the "announcement" package.
AnnouncementInverseTable = "announcements"
// AnnouncementColumn is the table column denoting the announcement relation/edge.
AnnouncementColumn = "announcement_id"
// UserTable is the table that holds the user relation/edge.
UserTable = "announcement_reads"
// UserInverseTable is the table name for the User entity.
// It exists in this package in order to avoid circular dependency with the "user" package.
UserInverseTable = "users"
// UserColumn is the table column denoting the user relation/edge.
UserColumn = "user_id"
)
// Columns holds all SQL columns for announcementread fields.
var Columns = []string{
FieldID,
FieldAnnouncementID,
FieldUserID,
FieldReadAt,
FieldCreatedAt,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultReadAt holds the default value on creation for the "read_at" field.
DefaultReadAt func() time.Time
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
)
// OrderOption defines the ordering options for the AnnouncementRead queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByAnnouncementID orders the results by the announcement_id field.
func ByAnnouncementID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAnnouncementID, opts...).ToFunc()
}
// ByUserID orders the results by the user_id field.
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserID, opts...).ToFunc()
}
// ByReadAt orders the results by the read_at field.
func ByReadAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldReadAt, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByAnnouncementField orders the results by announcement field.
func ByAnnouncementField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAnnouncementStep(), sql.OrderByField(field, opts...))
}
}
// ByUserField orders the results by user field.
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
}
}
func newAnnouncementStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AnnouncementInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, AnnouncementTable, AnnouncementColumn),
)
}
func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UserInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
}

View File

@@ -0,0 +1,257 @@
// Code generated by ent, DO NOT EDIT.
package announcementread
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldLTE(FieldID, id))
}
// AnnouncementID applies equality check predicate on the "announcement_id" field. It's identical to AnnouncementIDEQ.
func AnnouncementID(v int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldAnnouncementID, v))
}
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
func UserID(v int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldUserID, v))
}
// ReadAt applies equality check predicate on the "read_at" field. It's identical to ReadAtEQ.
func ReadAt(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldReadAt, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldCreatedAt, v))
}
// AnnouncementIDEQ applies the EQ predicate on the "announcement_id" field.
func AnnouncementIDEQ(v int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldAnnouncementID, v))
}
// AnnouncementIDNEQ applies the NEQ predicate on the "announcement_id" field.
func AnnouncementIDNEQ(v int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNEQ(FieldAnnouncementID, v))
}
// AnnouncementIDIn applies the In predicate on the "announcement_id" field.
func AnnouncementIDIn(vs ...int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldIn(FieldAnnouncementID, vs...))
}
// AnnouncementIDNotIn applies the NotIn predicate on the "announcement_id" field.
func AnnouncementIDNotIn(vs ...int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNotIn(FieldAnnouncementID, vs...))
}
// UserIDEQ applies the EQ predicate on the "user_id" field.
func UserIDEQ(v int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldUserID, v))
}
// UserIDNEQ applies the NEQ predicate on the "user_id" field.
func UserIDNEQ(v int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNEQ(FieldUserID, v))
}
// UserIDIn applies the In predicate on the "user_id" field.
func UserIDIn(vs ...int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldIn(FieldUserID, vs...))
}
// UserIDNotIn applies the NotIn predicate on the "user_id" field.
func UserIDNotIn(vs ...int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNotIn(FieldUserID, vs...))
}
// ReadAtEQ applies the EQ predicate on the "read_at" field.
func ReadAtEQ(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldReadAt, v))
}
// ReadAtNEQ applies the NEQ predicate on the "read_at" field.
func ReadAtNEQ(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNEQ(FieldReadAt, v))
}
// ReadAtIn applies the In predicate on the "read_at" field.
func ReadAtIn(vs ...time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldIn(FieldReadAt, vs...))
}
// ReadAtNotIn applies the NotIn predicate on the "read_at" field.
func ReadAtNotIn(vs ...time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNotIn(FieldReadAt, vs...))
}
// ReadAtGT applies the GT predicate on the "read_at" field.
func ReadAtGT(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldGT(FieldReadAt, v))
}
// ReadAtGTE applies the GTE predicate on the "read_at" field.
func ReadAtGTE(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldGTE(FieldReadAt, v))
}
// ReadAtLT applies the LT predicate on the "read_at" field.
func ReadAtLT(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldLT(FieldReadAt, v))
}
// ReadAtLTE applies the LTE predicate on the "read_at" field.
func ReadAtLTE(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldLTE(FieldReadAt, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldLTE(FieldCreatedAt, v))
}
// HasAnnouncement applies the HasEdge predicate on the "announcement" edge.
func HasAnnouncement() predicate.AnnouncementRead {
return predicate.AnnouncementRead(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, AnnouncementTable, AnnouncementColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasAnnouncementWith applies the HasEdge predicate on the "announcement" edge with a given conditions (other predicates).
func HasAnnouncementWith(preds ...predicate.Announcement) predicate.AnnouncementRead {
return predicate.AnnouncementRead(func(s *sql.Selector) {
step := newAnnouncementStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.AnnouncementRead {
return predicate.AnnouncementRead(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
func HasUserWith(preds ...predicate.User) predicate.AnnouncementRead {
return predicate.AnnouncementRead(func(s *sql.Selector) {
step := newUserStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.AnnouncementRead) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.AnnouncementRead) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.AnnouncementRead) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.NotPredicates(p))
}

View File

@@ -0,0 +1,660 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// AnnouncementReadCreate is the builder for creating a AnnouncementRead entity.
type AnnouncementReadCreate struct {
config
mutation *AnnouncementReadMutation
hooks []Hook
conflict []sql.ConflictOption
}
// SetAnnouncementID sets the "announcement_id" field.
func (_c *AnnouncementReadCreate) SetAnnouncementID(v int64) *AnnouncementReadCreate {
_c.mutation.SetAnnouncementID(v)
return _c
}
// SetUserID sets the "user_id" field.
func (_c *AnnouncementReadCreate) SetUserID(v int64) *AnnouncementReadCreate {
_c.mutation.SetUserID(v)
return _c
}
// SetReadAt sets the "read_at" field.
func (_c *AnnouncementReadCreate) SetReadAt(v time.Time) *AnnouncementReadCreate {
_c.mutation.SetReadAt(v)
return _c
}
// SetNillableReadAt sets the "read_at" field if the given value is not nil.
func (_c *AnnouncementReadCreate) SetNillableReadAt(v *time.Time) *AnnouncementReadCreate {
if v != nil {
_c.SetReadAt(*v)
}
return _c
}
// SetCreatedAt sets the "created_at" field.
func (_c *AnnouncementReadCreate) SetCreatedAt(v time.Time) *AnnouncementReadCreate {
_c.mutation.SetCreatedAt(v)
return _c
}
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
func (_c *AnnouncementReadCreate) SetNillableCreatedAt(v *time.Time) *AnnouncementReadCreate {
if v != nil {
_c.SetCreatedAt(*v)
}
return _c
}
// SetAnnouncement sets the "announcement" edge to the Announcement entity.
func (_c *AnnouncementReadCreate) SetAnnouncement(v *Announcement) *AnnouncementReadCreate {
return _c.SetAnnouncementID(v.ID)
}
// SetUser sets the "user" edge to the User entity.
func (_c *AnnouncementReadCreate) SetUser(v *User) *AnnouncementReadCreate {
return _c.SetUserID(v.ID)
}
// Mutation returns the AnnouncementReadMutation object of the builder.
func (_c *AnnouncementReadCreate) Mutation() *AnnouncementReadMutation {
return _c.mutation
}
// Save creates the AnnouncementRead in the database.
func (_c *AnnouncementReadCreate) Save(ctx context.Context) (*AnnouncementRead, error) {
_c.defaults()
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func (_c *AnnouncementReadCreate) SaveX(ctx context.Context) *AnnouncementRead {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *AnnouncementReadCreate) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *AnnouncementReadCreate) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_c *AnnouncementReadCreate) defaults() {
if _, ok := _c.mutation.ReadAt(); !ok {
v := announcementread.DefaultReadAt()
_c.mutation.SetReadAt(v)
}
if _, ok := _c.mutation.CreatedAt(); !ok {
v := announcementread.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_c *AnnouncementReadCreate) check() error {
if _, ok := _c.mutation.AnnouncementID(); !ok {
return &ValidationError{Name: "announcement_id", err: errors.New(`ent: missing required field "AnnouncementRead.announcement_id"`)}
}
if _, ok := _c.mutation.UserID(); !ok {
return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AnnouncementRead.user_id"`)}
}
if _, ok := _c.mutation.ReadAt(); !ok {
return &ValidationError{Name: "read_at", err: errors.New(`ent: missing required field "AnnouncementRead.read_at"`)}
}
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AnnouncementRead.created_at"`)}
}
if len(_c.mutation.AnnouncementIDs()) == 0 {
return &ValidationError{Name: "announcement", err: errors.New(`ent: missing required edge "AnnouncementRead.announcement"`)}
}
if len(_c.mutation.UserIDs()) == 0 {
return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AnnouncementRead.user"`)}
}
return nil
}
func (_c *AnnouncementReadCreate) sqlSave(ctx context.Context) (*AnnouncementRead, error) {
if err := _c.check(); err != nil {
return nil, err
}
_node, _spec := _c.createSpec()
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
id := _spec.ID.Value.(int64)
_node.ID = int64(id)
_c.mutation.id = &_node.ID
_c.mutation.done = true
return _node, nil
}
func (_c *AnnouncementReadCreate) createSpec() (*AnnouncementRead, *sqlgraph.CreateSpec) {
var (
_node = &AnnouncementRead{config: _c.config}
_spec = sqlgraph.NewCreateSpec(announcementread.Table, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
)
_spec.OnConflict = _c.conflict
if value, ok := _c.mutation.ReadAt(); ok {
_spec.SetField(announcementread.FieldReadAt, field.TypeTime, value)
_node.ReadAt = value
}
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(announcementread.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
if nodes := _c.mutation.AnnouncementIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.AnnouncementTable,
Columns: []string{announcementread.AnnouncementColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.AnnouncementID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.UserTable,
Columns: []string{announcementread.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.UserID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.AnnouncementRead.Create().
// SetAnnouncementID(v).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.AnnouncementReadUpsert) {
// SetAnnouncementID(v+v).
// }).
// Exec(ctx)
func (_c *AnnouncementReadCreate) OnConflict(opts ...sql.ConflictOption) *AnnouncementReadUpsertOne {
_c.conflict = opts
return &AnnouncementReadUpsertOne{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.AnnouncementRead.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *AnnouncementReadCreate) OnConflictColumns(columns ...string) *AnnouncementReadUpsertOne {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &AnnouncementReadUpsertOne{
create: _c,
}
}
type (
// AnnouncementReadUpsertOne is the builder for "upsert"-ing
// one AnnouncementRead node.
AnnouncementReadUpsertOne struct {
create *AnnouncementReadCreate
}
// AnnouncementReadUpsert is the "OnConflict" setter.
AnnouncementReadUpsert struct {
*sql.UpdateSet
}
)
// SetAnnouncementID sets the "announcement_id" field.
func (u *AnnouncementReadUpsert) SetAnnouncementID(v int64) *AnnouncementReadUpsert {
u.Set(announcementread.FieldAnnouncementID, v)
return u
}
// UpdateAnnouncementID sets the "announcement_id" field to the value that was provided on create.
func (u *AnnouncementReadUpsert) UpdateAnnouncementID() *AnnouncementReadUpsert {
u.SetExcluded(announcementread.FieldAnnouncementID)
return u
}
// SetUserID sets the "user_id" field.
func (u *AnnouncementReadUpsert) SetUserID(v int64) *AnnouncementReadUpsert {
u.Set(announcementread.FieldUserID, v)
return u
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *AnnouncementReadUpsert) UpdateUserID() *AnnouncementReadUpsert {
u.SetExcluded(announcementread.FieldUserID)
return u
}
// SetReadAt sets the "read_at" field.
func (u *AnnouncementReadUpsert) SetReadAt(v time.Time) *AnnouncementReadUpsert {
u.Set(announcementread.FieldReadAt, v)
return u
}
// UpdateReadAt sets the "read_at" field to the value that was provided on create.
func (u *AnnouncementReadUpsert) UpdateReadAt() *AnnouncementReadUpsert {
u.SetExcluded(announcementread.FieldReadAt)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
// client.AnnouncementRead.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *AnnouncementReadUpsertOne) UpdateNewValues() *AnnouncementReadUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
if _, exists := u.create.mutation.CreatedAt(); exists {
s.SetIgnore(announcementread.FieldCreatedAt)
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.AnnouncementRead.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *AnnouncementReadUpsertOne) Ignore() *AnnouncementReadUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *AnnouncementReadUpsertOne) DoNothing() *AnnouncementReadUpsertOne {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the AnnouncementReadCreate.OnConflict
// documentation for more info.
func (u *AnnouncementReadUpsertOne) Update(set func(*AnnouncementReadUpsert)) *AnnouncementReadUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&AnnouncementReadUpsert{UpdateSet: update})
}))
return u
}
// SetAnnouncementID sets the "announcement_id" field.
func (u *AnnouncementReadUpsertOne) SetAnnouncementID(v int64) *AnnouncementReadUpsertOne {
return u.Update(func(s *AnnouncementReadUpsert) {
s.SetAnnouncementID(v)
})
}
// UpdateAnnouncementID sets the "announcement_id" field to the value that was provided on create.
func (u *AnnouncementReadUpsertOne) UpdateAnnouncementID() *AnnouncementReadUpsertOne {
return u.Update(func(s *AnnouncementReadUpsert) {
s.UpdateAnnouncementID()
})
}
// SetUserID sets the "user_id" field.
func (u *AnnouncementReadUpsertOne) SetUserID(v int64) *AnnouncementReadUpsertOne {
return u.Update(func(s *AnnouncementReadUpsert) {
s.SetUserID(v)
})
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *AnnouncementReadUpsertOne) UpdateUserID() *AnnouncementReadUpsertOne {
return u.Update(func(s *AnnouncementReadUpsert) {
s.UpdateUserID()
})
}
// SetReadAt sets the "read_at" field.
func (u *AnnouncementReadUpsertOne) SetReadAt(v time.Time) *AnnouncementReadUpsertOne {
return u.Update(func(s *AnnouncementReadUpsert) {
s.SetReadAt(v)
})
}
// UpdateReadAt sets the "read_at" field to the value that was provided on create.
func (u *AnnouncementReadUpsertOne) UpdateReadAt() *AnnouncementReadUpsertOne {
return u.Update(func(s *AnnouncementReadUpsert) {
s.UpdateReadAt()
})
}
// Exec executes the query.
func (u *AnnouncementReadUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for AnnouncementReadCreate.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *AnnouncementReadUpsertOne) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}
// Exec executes the UPSERT query and returns the inserted/updated ID.
func (u *AnnouncementReadUpsertOne) ID(ctx context.Context) (id int64, err error) {
node, err := u.create.Save(ctx)
if err != nil {
return id, err
}
return node.ID, nil
}
// IDX is like ID, but panics if an error occurs.
func (u *AnnouncementReadUpsertOne) IDX(ctx context.Context) int64 {
id, err := u.ID(ctx)
if err != nil {
panic(err)
}
return id
}
// AnnouncementReadCreateBulk is the builder for creating many AnnouncementRead entities in bulk.
type AnnouncementReadCreateBulk struct {
config
err error
builders []*AnnouncementReadCreate
conflict []sql.ConflictOption
}
// Save creates the AnnouncementRead entities in the database.
func (_c *AnnouncementReadCreateBulk) Save(ctx context.Context) ([]*AnnouncementRead, error) {
if _c.err != nil {
return nil, _c.err
}
specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
nodes := make([]*AnnouncementRead, len(_c.builders))
mutators := make([]Mutator, len(_c.builders))
for i := range _c.builders {
func(i int, root context.Context) {
builder := _c.builders[i]
builder.defaults()
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*AnnouncementReadMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err := builder.check(); err != nil {
return nil, err
}
builder.mutation = mutation
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
spec.OnConflict = _c.conflict
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
}
}
if err != nil {
return nil, err
}
mutation.id = &nodes[i].ID
if specs[i].ID.Value != nil {
id := specs[i].ID.Value.(int64)
nodes[i].ID = int64(id)
}
mutation.done = true
return nodes[i], nil
})
for i := len(builder.hooks) - 1; i >= 0; i-- {
mut = builder.hooks[i](mut)
}
mutators[i] = mut
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
return nil, err
}
}
return nodes, nil
}
// SaveX is like Save, but panics if an error occurs.
func (_c *AnnouncementReadCreateBulk) SaveX(ctx context.Context) []*AnnouncementRead {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *AnnouncementReadCreateBulk) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *AnnouncementReadCreateBulk) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.AnnouncementRead.CreateBulk(builders...).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.AnnouncementReadUpsert) {
// SetAnnouncementID(v+v).
// }).
// Exec(ctx)
func (_c *AnnouncementReadCreateBulk) OnConflict(opts ...sql.ConflictOption) *AnnouncementReadUpsertBulk {
_c.conflict = opts
return &AnnouncementReadUpsertBulk{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.AnnouncementRead.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *AnnouncementReadCreateBulk) OnConflictColumns(columns ...string) *AnnouncementReadUpsertBulk {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &AnnouncementReadUpsertBulk{
create: _c,
}
}
// AnnouncementReadUpsertBulk is the builder for "upsert"-ing
// a bulk of AnnouncementRead nodes.
type AnnouncementReadUpsertBulk struct {
create *AnnouncementReadCreateBulk
}
// UpdateNewValues updates the mutable fields using the new values that
// were set on create. Using this option is equivalent to using:
//
// client.AnnouncementRead.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *AnnouncementReadUpsertBulk) UpdateNewValues() *AnnouncementReadUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
for _, b := range u.create.builders {
if _, exists := b.mutation.CreatedAt(); exists {
s.SetIgnore(announcementread.FieldCreatedAt)
}
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.AnnouncementRead.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *AnnouncementReadUpsertBulk) Ignore() *AnnouncementReadUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *AnnouncementReadUpsertBulk) DoNothing() *AnnouncementReadUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the AnnouncementReadCreateBulk.OnConflict
// documentation for more info.
func (u *AnnouncementReadUpsertBulk) Update(set func(*AnnouncementReadUpsert)) *AnnouncementReadUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&AnnouncementReadUpsert{UpdateSet: update})
}))
return u
}
// SetAnnouncementID sets the "announcement_id" field.
func (u *AnnouncementReadUpsertBulk) SetAnnouncementID(v int64) *AnnouncementReadUpsertBulk {
return u.Update(func(s *AnnouncementReadUpsert) {
s.SetAnnouncementID(v)
})
}
// UpdateAnnouncementID sets the "announcement_id" field to the value that was provided on create.
func (u *AnnouncementReadUpsertBulk) UpdateAnnouncementID() *AnnouncementReadUpsertBulk {
return u.Update(func(s *AnnouncementReadUpsert) {
s.UpdateAnnouncementID()
})
}
// SetUserID sets the "user_id" field.
func (u *AnnouncementReadUpsertBulk) SetUserID(v int64) *AnnouncementReadUpsertBulk {
return u.Update(func(s *AnnouncementReadUpsert) {
s.SetUserID(v)
})
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *AnnouncementReadUpsertBulk) UpdateUserID() *AnnouncementReadUpsertBulk {
return u.Update(func(s *AnnouncementReadUpsert) {
s.UpdateUserID()
})
}
// SetReadAt sets the "read_at" field.
func (u *AnnouncementReadUpsertBulk) SetReadAt(v time.Time) *AnnouncementReadUpsertBulk {
return u.Update(func(s *AnnouncementReadUpsert) {
s.SetReadAt(v)
})
}
// UpdateReadAt sets the "read_at" field to the value that was provided on create.
func (u *AnnouncementReadUpsertBulk) UpdateReadAt() *AnnouncementReadUpsertBulk {
return u.Update(func(s *AnnouncementReadUpsert) {
s.UpdateReadAt()
})
}
// Exec executes the query.
func (u *AnnouncementReadUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
return u.create.err
}
for i, b := range u.create.builders {
if len(b.conflict) != 0 {
return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AnnouncementReadCreateBulk instead", i)
}
}
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for AnnouncementReadCreateBulk.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *AnnouncementReadUpsertBulk) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// AnnouncementReadDelete is the builder for deleting a AnnouncementRead entity.
type AnnouncementReadDelete struct {
config
hooks []Hook
mutation *AnnouncementReadMutation
}
// Where appends a list predicates to the AnnouncementReadDelete builder.
func (_d *AnnouncementReadDelete) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *AnnouncementReadDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AnnouncementReadDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *AnnouncementReadDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(announcementread.Table, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// AnnouncementReadDeleteOne is the builder for deleting a single AnnouncementRead entity.
type AnnouncementReadDeleteOne struct {
_d *AnnouncementReadDelete
}
// Where appends a list predicates to the AnnouncementReadDelete builder.
func (_d *AnnouncementReadDeleteOne) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *AnnouncementReadDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{announcementread.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AnnouncementReadDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,718 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// AnnouncementReadQuery is the builder for querying AnnouncementRead entities.
type AnnouncementReadQuery struct {
config
ctx *QueryContext
order []announcementread.OrderOption
inters []Interceptor
predicates []predicate.AnnouncementRead
withAnnouncement *AnnouncementQuery
withUser *UserQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the AnnouncementReadQuery builder.
func (_q *AnnouncementReadQuery) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *AnnouncementReadQuery) Limit(limit int) *AnnouncementReadQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *AnnouncementReadQuery) Offset(offset int) *AnnouncementReadQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *AnnouncementReadQuery) Unique(unique bool) *AnnouncementReadQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *AnnouncementReadQuery) Order(o ...announcementread.OrderOption) *AnnouncementReadQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryAnnouncement chains the current query on the "announcement" edge.
func (_q *AnnouncementReadQuery) QueryAnnouncement() *AnnouncementQuery {
query := (&AnnouncementClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(announcementread.Table, announcementread.FieldID, selector),
sqlgraph.To(announcement.Table, announcement.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, announcementread.AnnouncementTable, announcementread.AnnouncementColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUser chains the current query on the "user" edge.
func (_q *AnnouncementReadQuery) QueryUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(announcementread.Table, announcementread.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, announcementread.UserTable, announcementread.UserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first AnnouncementRead entity from the query.
// Returns a *NotFoundError when no AnnouncementRead was found.
func (_q *AnnouncementReadQuery) First(ctx context.Context) (*AnnouncementRead, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{announcementread.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *AnnouncementReadQuery) FirstX(ctx context.Context) *AnnouncementRead {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first AnnouncementRead ID from the query.
// Returns a *NotFoundError when no AnnouncementRead ID was found.
func (_q *AnnouncementReadQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{announcementread.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *AnnouncementReadQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single AnnouncementRead entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one AnnouncementRead entity is found.
// Returns a *NotFoundError when no AnnouncementRead entities are found.
func (_q *AnnouncementReadQuery) Only(ctx context.Context) (*AnnouncementRead, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{announcementread.Label}
default:
return nil, &NotSingularError{announcementread.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *AnnouncementReadQuery) OnlyX(ctx context.Context) *AnnouncementRead {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only AnnouncementRead ID in the query.
// Returns a *NotSingularError when more than one AnnouncementRead ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *AnnouncementReadQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{announcementread.Label}
default:
err = &NotSingularError{announcementread.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *AnnouncementReadQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of AnnouncementReads.
func (_q *AnnouncementReadQuery) All(ctx context.Context) ([]*AnnouncementRead, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*AnnouncementRead, *AnnouncementReadQuery]()
return withInterceptors[[]*AnnouncementRead](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *AnnouncementReadQuery) AllX(ctx context.Context) []*AnnouncementRead {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of AnnouncementRead IDs.
func (_q *AnnouncementReadQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(announcementread.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *AnnouncementReadQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *AnnouncementReadQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*AnnouncementReadQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *AnnouncementReadQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *AnnouncementReadQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *AnnouncementReadQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the AnnouncementReadQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *AnnouncementReadQuery) Clone() *AnnouncementReadQuery {
if _q == nil {
return nil
}
return &AnnouncementReadQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]announcementread.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.AnnouncementRead{}, _q.predicates...),
withAnnouncement: _q.withAnnouncement.Clone(),
withUser: _q.withUser.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithAnnouncement tells the query-builder to eager-load the nodes that are connected to
// the "announcement" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AnnouncementReadQuery) WithAnnouncement(opts ...func(*AnnouncementQuery)) *AnnouncementReadQuery {
query := (&AnnouncementClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAnnouncement = query
return _q
}
// WithUser tells the query-builder to eager-load the nodes that are connected to
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AnnouncementReadQuery) WithUser(opts ...func(*UserQuery)) *AnnouncementReadQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUser = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// AnnouncementID int64 `json:"announcement_id,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.AnnouncementRead.Query().
// GroupBy(announcementread.FieldAnnouncementID).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *AnnouncementReadQuery) GroupBy(field string, fields ...string) *AnnouncementReadGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &AnnouncementReadGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = announcementread.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// AnnouncementID int64 `json:"announcement_id,omitempty"`
// }
//
// client.AnnouncementRead.Query().
// Select(announcementread.FieldAnnouncementID).
// Scan(ctx, &v)
func (_q *AnnouncementReadQuery) Select(fields ...string) *AnnouncementReadSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &AnnouncementReadSelect{AnnouncementReadQuery: _q}
sbuild.label = announcementread.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a AnnouncementReadSelect configured with the given aggregations.
func (_q *AnnouncementReadQuery) Aggregate(fns ...AggregateFunc) *AnnouncementReadSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *AnnouncementReadQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !announcementread.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *AnnouncementReadQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AnnouncementRead, error) {
var (
nodes = []*AnnouncementRead{}
_spec = _q.querySpec()
loadedTypes = [2]bool{
_q.withAnnouncement != nil,
_q.withUser != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*AnnouncementRead).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &AnnouncementRead{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withAnnouncement; query != nil {
if err := _q.loadAnnouncement(ctx, query, nodes, nil,
func(n *AnnouncementRead, e *Announcement) { n.Edges.Announcement = e }); err != nil {
return nil, err
}
}
if query := _q.withUser; query != nil {
if err := _q.loadUser(ctx, query, nodes, nil,
func(n *AnnouncementRead, e *User) { n.Edges.User = e }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *AnnouncementReadQuery) loadAnnouncement(ctx context.Context, query *AnnouncementQuery, nodes []*AnnouncementRead, init func(*AnnouncementRead), assign func(*AnnouncementRead, *Announcement)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*AnnouncementRead)
for i := range nodes {
fk := nodes[i].AnnouncementID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(announcement.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "announcement_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *AnnouncementReadQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AnnouncementRead, init func(*AnnouncementRead), assign func(*AnnouncementRead, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*AnnouncementRead)
for i := range nodes {
fk := nodes[i].UserID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *AnnouncementReadQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *AnnouncementReadQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(announcementread.Table, announcementread.Columns, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, announcementread.FieldID)
for i := range fields {
if fields[i] != announcementread.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withAnnouncement != nil {
_spec.Node.AddColumnOnce(announcementread.FieldAnnouncementID)
}
if _q.withUser != nil {
_spec.Node.AddColumnOnce(announcementread.FieldUserID)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *AnnouncementReadQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(announcementread.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = announcementread.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *AnnouncementReadQuery) ForUpdate(opts ...sql.LockOption) *AnnouncementReadQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *AnnouncementReadQuery) ForShare(opts ...sql.LockOption) *AnnouncementReadQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// AnnouncementReadGroupBy is the group-by builder for AnnouncementRead entities.
type AnnouncementReadGroupBy struct {
selector
build *AnnouncementReadQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *AnnouncementReadGroupBy) Aggregate(fns ...AggregateFunc) *AnnouncementReadGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *AnnouncementReadGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AnnouncementReadQuery, *AnnouncementReadGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *AnnouncementReadGroupBy) sqlScan(ctx context.Context, root *AnnouncementReadQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// AnnouncementReadSelect is the builder for selecting fields of AnnouncementRead entities.
type AnnouncementReadSelect struct {
*AnnouncementReadQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *AnnouncementReadSelect) Aggregate(fns ...AggregateFunc) *AnnouncementReadSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *AnnouncementReadSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AnnouncementReadQuery, *AnnouncementReadSelect](ctx, _s.AnnouncementReadQuery, _s, _s.inters, v)
}
func (_s *AnnouncementReadSelect) sqlScan(ctx context.Context, root *AnnouncementReadQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,456 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// AnnouncementReadUpdate is the builder for updating AnnouncementRead entities.
type AnnouncementReadUpdate struct {
config
hooks []Hook
mutation *AnnouncementReadMutation
}
// Where appends a list predicates to the AnnouncementReadUpdate builder.
func (_u *AnnouncementReadUpdate) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetAnnouncementID sets the "announcement_id" field.
func (_u *AnnouncementReadUpdate) SetAnnouncementID(v int64) *AnnouncementReadUpdate {
_u.mutation.SetAnnouncementID(v)
return _u
}
// SetNillableAnnouncementID sets the "announcement_id" field if the given value is not nil.
func (_u *AnnouncementReadUpdate) SetNillableAnnouncementID(v *int64) *AnnouncementReadUpdate {
if v != nil {
_u.SetAnnouncementID(*v)
}
return _u
}
// SetUserID sets the "user_id" field.
func (_u *AnnouncementReadUpdate) SetUserID(v int64) *AnnouncementReadUpdate {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *AnnouncementReadUpdate) SetNillableUserID(v *int64) *AnnouncementReadUpdate {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetReadAt sets the "read_at" field.
func (_u *AnnouncementReadUpdate) SetReadAt(v time.Time) *AnnouncementReadUpdate {
_u.mutation.SetReadAt(v)
return _u
}
// SetNillableReadAt sets the "read_at" field if the given value is not nil.
func (_u *AnnouncementReadUpdate) SetNillableReadAt(v *time.Time) *AnnouncementReadUpdate {
if v != nil {
_u.SetReadAt(*v)
}
return _u
}
// SetAnnouncement sets the "announcement" edge to the Announcement entity.
func (_u *AnnouncementReadUpdate) SetAnnouncement(v *Announcement) *AnnouncementReadUpdate {
return _u.SetAnnouncementID(v.ID)
}
// SetUser sets the "user" edge to the User entity.
func (_u *AnnouncementReadUpdate) SetUser(v *User) *AnnouncementReadUpdate {
return _u.SetUserID(v.ID)
}
// Mutation returns the AnnouncementReadMutation object of the builder.
func (_u *AnnouncementReadUpdate) Mutation() *AnnouncementReadMutation {
return _u.mutation
}
// ClearAnnouncement clears the "announcement" edge to the Announcement entity.
func (_u *AnnouncementReadUpdate) ClearAnnouncement() *AnnouncementReadUpdate {
_u.mutation.ClearAnnouncement()
return _u
}
// ClearUser clears the "user" edge to the User entity.
func (_u *AnnouncementReadUpdate) ClearUser() *AnnouncementReadUpdate {
_u.mutation.ClearUser()
return _u
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *AnnouncementReadUpdate) Save(ctx context.Context) (int, error) {
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AnnouncementReadUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *AnnouncementReadUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AnnouncementReadUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *AnnouncementReadUpdate) check() error {
if _u.mutation.AnnouncementCleared() && len(_u.mutation.AnnouncementIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "AnnouncementRead.announcement"`)
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "AnnouncementRead.user"`)
}
return nil
}
func (_u *AnnouncementReadUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(announcementread.Table, announcementread.Columns, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.ReadAt(); ok {
_spec.SetField(announcementread.FieldReadAt, field.TypeTime, value)
}
if _u.mutation.AnnouncementCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.AnnouncementTable,
Columns: []string{announcementread.AnnouncementColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AnnouncementIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.AnnouncementTable,
Columns: []string{announcementread.AnnouncementColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.UserTable,
Columns: []string{announcementread.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.UserTable,
Columns: []string{announcementread.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{announcementread.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// AnnouncementReadUpdateOne is the builder for updating a single AnnouncementRead entity.
type AnnouncementReadUpdateOne struct {
config
fields []string
hooks []Hook
mutation *AnnouncementReadMutation
}
// SetAnnouncementID sets the "announcement_id" field.
func (_u *AnnouncementReadUpdateOne) SetAnnouncementID(v int64) *AnnouncementReadUpdateOne {
_u.mutation.SetAnnouncementID(v)
return _u
}
// SetNillableAnnouncementID sets the "announcement_id" field if the given value is not nil.
func (_u *AnnouncementReadUpdateOne) SetNillableAnnouncementID(v *int64) *AnnouncementReadUpdateOne {
if v != nil {
_u.SetAnnouncementID(*v)
}
return _u
}
// SetUserID sets the "user_id" field.
func (_u *AnnouncementReadUpdateOne) SetUserID(v int64) *AnnouncementReadUpdateOne {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *AnnouncementReadUpdateOne) SetNillableUserID(v *int64) *AnnouncementReadUpdateOne {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetReadAt sets the "read_at" field.
func (_u *AnnouncementReadUpdateOne) SetReadAt(v time.Time) *AnnouncementReadUpdateOne {
_u.mutation.SetReadAt(v)
return _u
}
// SetNillableReadAt sets the "read_at" field if the given value is not nil.
func (_u *AnnouncementReadUpdateOne) SetNillableReadAt(v *time.Time) *AnnouncementReadUpdateOne {
if v != nil {
_u.SetReadAt(*v)
}
return _u
}
// SetAnnouncement sets the "announcement" edge to the Announcement entity.
func (_u *AnnouncementReadUpdateOne) SetAnnouncement(v *Announcement) *AnnouncementReadUpdateOne {
return _u.SetAnnouncementID(v.ID)
}
// SetUser sets the "user" edge to the User entity.
func (_u *AnnouncementReadUpdateOne) SetUser(v *User) *AnnouncementReadUpdateOne {
return _u.SetUserID(v.ID)
}
// Mutation returns the AnnouncementReadMutation object of the builder.
func (_u *AnnouncementReadUpdateOne) Mutation() *AnnouncementReadMutation {
return _u.mutation
}
// ClearAnnouncement clears the "announcement" edge to the Announcement entity.
func (_u *AnnouncementReadUpdateOne) ClearAnnouncement() *AnnouncementReadUpdateOne {
_u.mutation.ClearAnnouncement()
return _u
}
// ClearUser clears the "user" edge to the User entity.
func (_u *AnnouncementReadUpdateOne) ClearUser() *AnnouncementReadUpdateOne {
_u.mutation.ClearUser()
return _u
}
// Where appends a list predicates to the AnnouncementReadUpdate builder.
func (_u *AnnouncementReadUpdateOne) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *AnnouncementReadUpdateOne) Select(field string, fields ...string) *AnnouncementReadUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated AnnouncementRead entity.
func (_u *AnnouncementReadUpdateOne) Save(ctx context.Context) (*AnnouncementRead, error) {
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AnnouncementReadUpdateOne) SaveX(ctx context.Context) *AnnouncementRead {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *AnnouncementReadUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AnnouncementReadUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *AnnouncementReadUpdateOne) check() error {
if _u.mutation.AnnouncementCleared() && len(_u.mutation.AnnouncementIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "AnnouncementRead.announcement"`)
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "AnnouncementRead.user"`)
}
return nil
}
func (_u *AnnouncementReadUpdateOne) sqlSave(ctx context.Context) (_node *AnnouncementRead, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(announcementread.Table, announcementread.Columns, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AnnouncementRead.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, announcementread.FieldID)
for _, f := range fields {
if !announcementread.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != announcementread.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.ReadAt(); ok {
_spec.SetField(announcementread.FieldReadAt, field.TypeTime, value)
}
if _u.mutation.AnnouncementCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.AnnouncementTable,
Columns: []string{announcementread.AnnouncementColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AnnouncementIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.AnnouncementTable,
Columns: []string{announcementread.AnnouncementColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.UserTable,
Columns: []string{announcementread.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.UserTable,
Columns: []string{announcementread.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &AnnouncementRead{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{announcementread.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -17,6 +17,8 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode"
@@ -46,6 +48,10 @@ type Client struct {
Account *AccountClient
// AccountGroup is the client for interacting with the AccountGroup builders.
AccountGroup *AccountGroupClient
// Announcement is the client for interacting with the Announcement builders.
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// PromoCode is the client for interacting with the PromoCode builders.
@@ -86,6 +92,8 @@ func (c *Client) init() {
c.APIKey = NewAPIKeyClient(c.config)
c.Account = NewAccountClient(c.config)
c.AccountGroup = NewAccountGroupClient(c.config)
c.Announcement = NewAnnouncementClient(c.config)
c.AnnouncementRead = NewAnnouncementReadClient(c.config)
c.Group = NewGroupClient(c.config)
c.PromoCode = NewPromoCodeClient(c.config)
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
@@ -194,6 +202,8 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
APIKey: NewAPIKeyClient(cfg),
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
Announcement: NewAnnouncementClient(cfg),
AnnouncementRead: NewAnnouncementReadClient(cfg),
Group: NewGroupClient(cfg),
PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
@@ -229,6 +239,8 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
APIKey: NewAPIKeyClient(cfg),
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
Announcement: NewAnnouncementClient(cfg),
AnnouncementRead: NewAnnouncementReadClient(cfg),
Group: NewGroupClient(cfg),
PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
@@ -271,10 +283,10 @@ func (c *Client) Close() error {
// In order to add hooks to a specific client, call: `client.Node.Use(...)`.
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting,
c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
} {
n.Use(hooks...)
}
@@ -284,10 +296,10 @@ func (c *Client) Use(hooks ...Hook) {
// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`.
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting,
c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
} {
n.Intercept(interceptors...)
}
@@ -302,6 +314,10 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.Account.mutate(ctx, m)
case *AccountGroupMutation:
return c.AccountGroup.mutate(ctx, m)
case *AnnouncementMutation:
return c.Announcement.mutate(ctx, m)
case *AnnouncementReadMutation:
return c.AnnouncementRead.mutate(ctx, m)
case *GroupMutation:
return c.Group.mutate(ctx, m)
case *PromoCodeMutation:
@@ -831,6 +847,320 @@ func (c *AccountGroupClient) mutate(ctx context.Context, m *AccountGroupMutation
}
}
// AnnouncementClient is a client for the Announcement schema.
type AnnouncementClient struct {
config
}
// NewAnnouncementClient returns a client for the Announcement from the given config.
func NewAnnouncementClient(c config) *AnnouncementClient {
return &AnnouncementClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `announcement.Hooks(f(g(h())))`.
func (c *AnnouncementClient) Use(hooks ...Hook) {
c.hooks.Announcement = append(c.hooks.Announcement, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `announcement.Intercept(f(g(h())))`.
func (c *AnnouncementClient) Intercept(interceptors ...Interceptor) {
c.inters.Announcement = append(c.inters.Announcement, interceptors...)
}
// Create returns a builder for creating a Announcement entity.
func (c *AnnouncementClient) Create() *AnnouncementCreate {
mutation := newAnnouncementMutation(c.config, OpCreate)
return &AnnouncementCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of Announcement entities.
func (c *AnnouncementClient) CreateBulk(builders ...*AnnouncementCreate) *AnnouncementCreateBulk {
return &AnnouncementCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *AnnouncementClient) MapCreateBulk(slice any, setFunc func(*AnnouncementCreate, int)) *AnnouncementCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &AnnouncementCreateBulk{err: fmt.Errorf("calling to AnnouncementClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*AnnouncementCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &AnnouncementCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for Announcement.
func (c *AnnouncementClient) Update() *AnnouncementUpdate {
mutation := newAnnouncementMutation(c.config, OpUpdate)
return &AnnouncementUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *AnnouncementClient) UpdateOne(_m *Announcement) *AnnouncementUpdateOne {
mutation := newAnnouncementMutation(c.config, OpUpdateOne, withAnnouncement(_m))
return &AnnouncementUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *AnnouncementClient) UpdateOneID(id int64) *AnnouncementUpdateOne {
mutation := newAnnouncementMutation(c.config, OpUpdateOne, withAnnouncementID(id))
return &AnnouncementUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for Announcement.
func (c *AnnouncementClient) Delete() *AnnouncementDelete {
mutation := newAnnouncementMutation(c.config, OpDelete)
return &AnnouncementDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *AnnouncementClient) DeleteOne(_m *Announcement) *AnnouncementDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *AnnouncementClient) DeleteOneID(id int64) *AnnouncementDeleteOne {
builder := c.Delete().Where(announcement.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &AnnouncementDeleteOne{builder}
}
// Query returns a query builder for Announcement.
func (c *AnnouncementClient) Query() *AnnouncementQuery {
return &AnnouncementQuery{
config: c.config,
ctx: &QueryContext{Type: TypeAnnouncement},
inters: c.Interceptors(),
}
}
// Get returns a Announcement entity by its id.
func (c *AnnouncementClient) Get(ctx context.Context, id int64) (*Announcement, error) {
return c.Query().Where(announcement.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *AnnouncementClient) GetX(ctx context.Context, id int64) *Announcement {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryReads queries the reads edge of a Announcement.
func (c *AnnouncementClient) QueryReads(_m *Announcement) *AnnouncementReadQuery {
query := (&AnnouncementReadClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(announcement.Table, announcement.FieldID, id),
sqlgraph.To(announcementread.Table, announcementread.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, announcement.ReadsTable, announcement.ReadsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *AnnouncementClient) Hooks() []Hook {
return c.hooks.Announcement
}
// Interceptors returns the client interceptors.
func (c *AnnouncementClient) Interceptors() []Interceptor {
return c.inters.Announcement
}
func (c *AnnouncementClient) mutate(ctx context.Context, m *AnnouncementMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&AnnouncementCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&AnnouncementUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&AnnouncementUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&AnnouncementDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown Announcement mutation op: %q", m.Op())
}
}
// AnnouncementReadClient is a client for the AnnouncementRead schema.
type AnnouncementReadClient struct {
config
}
// NewAnnouncementReadClient returns a client for the AnnouncementRead from the given config.
func NewAnnouncementReadClient(c config) *AnnouncementReadClient {
return &AnnouncementReadClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `announcementread.Hooks(f(g(h())))`.
func (c *AnnouncementReadClient) Use(hooks ...Hook) {
c.hooks.AnnouncementRead = append(c.hooks.AnnouncementRead, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `announcementread.Intercept(f(g(h())))`.
func (c *AnnouncementReadClient) Intercept(interceptors ...Interceptor) {
c.inters.AnnouncementRead = append(c.inters.AnnouncementRead, interceptors...)
}
// Create returns a builder for creating a AnnouncementRead entity.
func (c *AnnouncementReadClient) Create() *AnnouncementReadCreate {
mutation := newAnnouncementReadMutation(c.config, OpCreate)
return &AnnouncementReadCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of AnnouncementRead entities.
func (c *AnnouncementReadClient) CreateBulk(builders ...*AnnouncementReadCreate) *AnnouncementReadCreateBulk {
return &AnnouncementReadCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *AnnouncementReadClient) MapCreateBulk(slice any, setFunc func(*AnnouncementReadCreate, int)) *AnnouncementReadCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &AnnouncementReadCreateBulk{err: fmt.Errorf("calling to AnnouncementReadClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*AnnouncementReadCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &AnnouncementReadCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for AnnouncementRead.
func (c *AnnouncementReadClient) Update() *AnnouncementReadUpdate {
mutation := newAnnouncementReadMutation(c.config, OpUpdate)
return &AnnouncementReadUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *AnnouncementReadClient) UpdateOne(_m *AnnouncementRead) *AnnouncementReadUpdateOne {
mutation := newAnnouncementReadMutation(c.config, OpUpdateOne, withAnnouncementRead(_m))
return &AnnouncementReadUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *AnnouncementReadClient) UpdateOneID(id int64) *AnnouncementReadUpdateOne {
mutation := newAnnouncementReadMutation(c.config, OpUpdateOne, withAnnouncementReadID(id))
return &AnnouncementReadUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for AnnouncementRead.
func (c *AnnouncementReadClient) Delete() *AnnouncementReadDelete {
mutation := newAnnouncementReadMutation(c.config, OpDelete)
return &AnnouncementReadDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *AnnouncementReadClient) DeleteOne(_m *AnnouncementRead) *AnnouncementReadDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *AnnouncementReadClient) DeleteOneID(id int64) *AnnouncementReadDeleteOne {
builder := c.Delete().Where(announcementread.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &AnnouncementReadDeleteOne{builder}
}
// Query returns a query builder for AnnouncementRead.
func (c *AnnouncementReadClient) Query() *AnnouncementReadQuery {
return &AnnouncementReadQuery{
config: c.config,
ctx: &QueryContext{Type: TypeAnnouncementRead},
inters: c.Interceptors(),
}
}
// Get returns a AnnouncementRead entity by its id.
func (c *AnnouncementReadClient) Get(ctx context.Context, id int64) (*AnnouncementRead, error) {
return c.Query().Where(announcementread.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *AnnouncementReadClient) GetX(ctx context.Context, id int64) *AnnouncementRead {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryAnnouncement queries the announcement edge of a AnnouncementRead.
func (c *AnnouncementReadClient) QueryAnnouncement(_m *AnnouncementRead) *AnnouncementQuery {
query := (&AnnouncementClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(announcementread.Table, announcementread.FieldID, id),
sqlgraph.To(announcement.Table, announcement.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, announcementread.AnnouncementTable, announcementread.AnnouncementColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryUser queries the user edge of a AnnouncementRead.
func (c *AnnouncementReadClient) QueryUser(_m *AnnouncementRead) *UserQuery {
query := (&UserClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(announcementread.Table, announcementread.FieldID, id),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, announcementread.UserTable, announcementread.UserColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *AnnouncementReadClient) Hooks() []Hook {
return c.hooks.AnnouncementRead
}
// Interceptors returns the client interceptors.
func (c *AnnouncementReadClient) Interceptors() []Interceptor {
return c.inters.AnnouncementRead
}
func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementReadMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&AnnouncementReadCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&AnnouncementReadUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&AnnouncementReadUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&AnnouncementReadDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown AnnouncementRead mutation op: %q", m.Op())
}
}
// GroupClient is a client for the Group schema.
type GroupClient struct {
config
@@ -2375,6 +2705,22 @@ func (c *UserClient) QueryAssignedSubscriptions(_m *User) *UserSubscriptionQuery
return query
}
// QueryAnnouncementReads queries the announcement_reads edge of a User.
func (c *UserClient) QueryAnnouncementReads(_m *User) *AnnouncementReadQuery {
query := (&AnnouncementReadClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, id),
sqlgraph.To(announcementread.Table, announcementread.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.AnnouncementReadsTable, user.AnnouncementReadsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryAllowedGroups queries the allowed_groups edge of a User.
func (c *UserClient) QueryAllowedGroups(_m *User) *GroupQuery {
query := (&GroupClient{config: c.config}).Query()
@@ -3116,14 +3462,16 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
// hooks and interceptors per client, for fast access.
type (
hooks struct {
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode,
PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User,
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode,
PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User,
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
UserSubscription []ent.Interceptor
}
)

View File

@@ -14,6 +14,8 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode"
@@ -91,6 +93,8 @@ func checkColumn(t, c string) error {
apikey.Table: apikey.ValidColumn,
account.Table: account.ValidColumn,
accountgroup.Table: accountgroup.ValidColumn,
announcement.Table: announcement.ValidColumn,
announcementread.Table: announcementread.ValidColumn,
group.Table: group.ValidColumn,
promocode.Table: promocode.ValidColumn,
promocodeusage.Table: promocodeusage.ValidColumn,

View File

@@ -45,6 +45,30 @@ func (f AccountGroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AccountGroupMutation", m)
}
// The AnnouncementFunc type is an adapter to allow the use of ordinary
// function as Announcement mutator.
type AnnouncementFunc func(context.Context, *ent.AnnouncementMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f AnnouncementFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.AnnouncementMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementMutation", m)
}
// The AnnouncementReadFunc type is an adapter to allow the use of ordinary
// function as AnnouncementRead mutator.
type AnnouncementReadFunc func(context.Context, *ent.AnnouncementReadMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.AnnouncementReadMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m)
}
// The GroupFunc type is an adapter to allow the use of ordinary
// function as Group mutator.
type GroupFunc func(context.Context, *ent.GroupMutation) (ent.Value, error)

View File

@@ -10,6 +10,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
@@ -164,6 +166,60 @@ func (f TraverseAccountGroup) Traverse(ctx context.Context, q ent.Query) error {
return fmt.Errorf("unexpected query type %T. expect *ent.AccountGroupQuery", q)
}
// The AnnouncementFunc type is an adapter to allow the use of ordinary function as a Querier.
type AnnouncementFunc func(context.Context, *ent.AnnouncementQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f AnnouncementFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.AnnouncementQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementQuery", q)
}
// The TraverseAnnouncement type is an adapter to allow the use of ordinary function as Traverser.
type TraverseAnnouncement func(context.Context, *ent.AnnouncementQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseAnnouncement) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseAnnouncement) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.AnnouncementQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementQuery", q)
}
// The AnnouncementReadFunc type is an adapter to allow the use of ordinary function as a Querier.
type AnnouncementReadFunc func(context.Context, *ent.AnnouncementReadQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f AnnouncementReadFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.AnnouncementReadQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
}
// The TraverseAnnouncementRead type is an adapter to allow the use of ordinary function as Traverser.
type TraverseAnnouncementRead func(context.Context, *ent.AnnouncementReadQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseAnnouncementRead) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.AnnouncementReadQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
}
// The GroupFunc type is an adapter to allow the use of ordinary function as a Querier.
type GroupFunc func(context.Context, *ent.GroupQuery) (ent.Value, error)
@@ -524,6 +580,10 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.AccountQuery, predicate.Account, account.OrderOption]{typ: ent.TypeAccount, tq: q}, nil
case *ent.AccountGroupQuery:
return &query[*ent.AccountGroupQuery, predicate.AccountGroup, accountgroup.OrderOption]{typ: ent.TypeAccountGroup, tq: q}, nil
case *ent.AnnouncementQuery:
return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil
case *ent.AnnouncementReadQuery:
return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil
case *ent.GroupQuery:
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
case *ent.PromoCodeQuery:

View File

@@ -204,6 +204,98 @@ var (
},
},
}
// AnnouncementsColumns holds the columns for the "announcements" table.
AnnouncementsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "title", Type: field.TypeString, Size: 200},
{Name: "content", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
{Name: "status", Type: field.TypeString, Size: 20, Default: "draft"},
{Name: "targeting", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "starts_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "ends_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "created_by", Type: field.TypeInt64, Nullable: true},
{Name: "updated_by", Type: field.TypeInt64, Nullable: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
}
// AnnouncementsTable holds the schema information for the "announcements" table.
AnnouncementsTable = &schema.Table{
Name: "announcements",
Columns: AnnouncementsColumns,
PrimaryKey: []*schema.Column{AnnouncementsColumns[0]},
Indexes: []*schema.Index{
{
Name: "announcement_status",
Unique: false,
Columns: []*schema.Column{AnnouncementsColumns[3]},
},
{
Name: "announcement_created_at",
Unique: false,
Columns: []*schema.Column{AnnouncementsColumns[9]},
},
{
Name: "announcement_starts_at",
Unique: false,
Columns: []*schema.Column{AnnouncementsColumns[5]},
},
{
Name: "announcement_ends_at",
Unique: false,
Columns: []*schema.Column{AnnouncementsColumns[6]},
},
},
}
// AnnouncementReadsColumns holds the columns for the "announcement_reads" table.
AnnouncementReadsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "read_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "announcement_id", Type: field.TypeInt64},
{Name: "user_id", Type: field.TypeInt64},
}
// AnnouncementReadsTable holds the schema information for the "announcement_reads" table.
AnnouncementReadsTable = &schema.Table{
Name: "announcement_reads",
Columns: AnnouncementReadsColumns,
PrimaryKey: []*schema.Column{AnnouncementReadsColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "announcement_reads_announcements_reads",
Columns: []*schema.Column{AnnouncementReadsColumns[3]},
RefColumns: []*schema.Column{AnnouncementsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "announcement_reads_users_announcement_reads",
Columns: []*schema.Column{AnnouncementReadsColumns[4]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
},
Indexes: []*schema.Index{
{
Name: "announcementread_announcement_id",
Unique: false,
Columns: []*schema.Column{AnnouncementReadsColumns[3]},
},
{
Name: "announcementread_user_id",
Unique: false,
Columns: []*schema.Column{AnnouncementReadsColumns[4]},
},
{
Name: "announcementread_read_at",
Unique: false,
Columns: []*schema.Column{AnnouncementReadsColumns[1]},
},
{
Name: "announcementread_announcement_id_user_id",
Unique: true,
Columns: []*schema.Column{AnnouncementReadsColumns[3], AnnouncementReadsColumns[4]},
},
},
}
// GroupsColumns holds the columns for the "groups" table.
GroupsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -610,6 +702,9 @@ var (
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
@@ -837,6 +932,8 @@ var (
APIKeysTable,
AccountsTable,
AccountGroupsTable,
AnnouncementsTable,
AnnouncementReadsTable,
GroupsTable,
PromoCodesTable,
PromoCodeUsagesTable,
@@ -868,6 +965,14 @@ func init() {
AccountGroupsTable.Annotation = &entsql.Annotation{
Table: "account_groups",
}
AnnouncementsTable.Annotation = &entsql.Annotation{
Table: "announcements",
}
AnnouncementReadsTable.ForeignKeys[0].RefTable = AnnouncementsTable
AnnouncementReadsTable.ForeignKeys[1].RefTable = UsersTable
AnnouncementReadsTable.Annotation = &entsql.Annotation{
Table: "announcement_reads",
}
GroupsTable.Annotation = &entsql.Annotation{
Table: "groups",
}

File diff suppressed because it is too large Load Diff

View File

@@ -15,6 +15,12 @@ type Account func(*sql.Selector)
// AccountGroup is the predicate function for accountgroup builders.
type AccountGroup func(*sql.Selector)
// Announcement is the predicate function for announcement builders.
type Announcement func(*sql.Selector)
// AnnouncementRead is the predicate function for announcementread builders.
type AnnouncementRead func(*sql.Selector)
// Group is the predicate function for group builders.
type Group func(*sql.Selector)

View File

@@ -7,6 +7,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode"
@@ -210,6 +212,56 @@ func init() {
accountgroupDescCreatedAt := accountgroupFields[3].Descriptor()
// accountgroup.DefaultCreatedAt holds the default value on creation for the created_at field.
accountgroup.DefaultCreatedAt = accountgroupDescCreatedAt.Default.(func() time.Time)
announcementFields := schema.Announcement{}.Fields()
_ = announcementFields
// announcementDescTitle is the schema descriptor for title field.
announcementDescTitle := announcementFields[0].Descriptor()
// announcement.TitleValidator is a validator for the "title" field. It is called by the builders before save.
announcement.TitleValidator = func() func(string) error {
validators := announcementDescTitle.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(title string) error {
for _, fn := range fns {
if err := fn(title); err != nil {
return err
}
}
return nil
}
}()
// announcementDescContent is the schema descriptor for content field.
announcementDescContent := announcementFields[1].Descriptor()
// announcement.ContentValidator is a validator for the "content" field. It is called by the builders before save.
announcement.ContentValidator = announcementDescContent.Validators[0].(func(string) error)
// announcementDescStatus is the schema descriptor for status field.
announcementDescStatus := announcementFields[2].Descriptor()
// announcement.DefaultStatus holds the default value on creation for the status field.
announcement.DefaultStatus = announcementDescStatus.Default.(string)
// announcement.StatusValidator is a validator for the "status" field. It is called by the builders before save.
announcement.StatusValidator = announcementDescStatus.Validators[0].(func(string) error)
// announcementDescCreatedAt is the schema descriptor for created_at field.
announcementDescCreatedAt := announcementFields[8].Descriptor()
// announcement.DefaultCreatedAt holds the default value on creation for the created_at field.
announcement.DefaultCreatedAt = announcementDescCreatedAt.Default.(func() time.Time)
// announcementDescUpdatedAt is the schema descriptor for updated_at field.
announcementDescUpdatedAt := announcementFields[9].Descriptor()
// announcement.DefaultUpdatedAt holds the default value on creation for the updated_at field.
announcement.DefaultUpdatedAt = announcementDescUpdatedAt.Default.(func() time.Time)
// announcement.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
announcement.UpdateDefaultUpdatedAt = announcementDescUpdatedAt.UpdateDefault.(func() time.Time)
announcementreadFields := schema.AnnouncementRead{}.Fields()
_ = announcementreadFields
// announcementreadDescReadAt is the schema descriptor for read_at field.
announcementreadDescReadAt := announcementreadFields[2].Descriptor()
// announcementread.DefaultReadAt holds the default value on creation for the read_at field.
announcementread.DefaultReadAt = announcementreadDescReadAt.Default.(func() time.Time)
// announcementreadDescCreatedAt is the schema descriptor for created_at field.
announcementreadDescCreatedAt := announcementreadFields[3].Descriptor()
// announcementread.DefaultCreatedAt holds the default value on creation for the created_at field.
announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time)
groupMixin := schema.Group{}.Mixin()
groupMixinHooks1 := groupMixin[1].Hooks()
group.Hooks[0] = groupMixinHooks1[0]
@@ -736,6 +788,10 @@ func init() {
userDescNotes := userFields[7].Descriptor()
// user.DefaultNotes holds the default value on creation for the notes field.
user.DefaultNotes = userDescNotes.Default.(string)
// userDescTotpEnabled is the schema descriptor for totp_enabled field.
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.

View File

@@ -4,7 +4,7 @@ package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -111,7 +111,7 @@ func (Account) Fields() []ent.Field {
// status: 账户状态,如 "active", "error", "disabled"
field.String("status").
MaxLen(20).
Default(service.StatusActive),
Default(domain.StatusActive),
// error_message: 错误信息,记录账户异常时的详细信息
field.String("error_message").

View File

@@ -0,0 +1,90 @@
package schema
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// Announcement holds the schema definition for the Announcement entity.
//
// 删除策略:硬删除(已读记录通过外键级联删除)
type Announcement struct {
ent.Schema
}
func (Announcement) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "announcements"},
}
}
func (Announcement) Fields() []ent.Field {
return []ent.Field{
field.String("title").
MaxLen(200).
NotEmpty().
Comment("公告标题"),
field.String("content").
SchemaType(map[string]string{dialect.Postgres: "text"}).
NotEmpty().
Comment("公告内容(支持 Markdown"),
field.String("status").
MaxLen(20).
Default(domain.AnnouncementStatusDraft).
Comment("状态: draft, active, archived"),
field.JSON("targeting", domain.AnnouncementTargeting{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("展示条件JSON 规则)"),
field.Time("starts_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}).
Comment("开始展示时间(为空表示立即生效)"),
field.Time("ends_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}).
Comment("结束展示时间(为空表示永久生效)"),
field.Int64("created_by").
Optional().
Nillable().
Comment("创建人用户ID管理员"),
field.Int64("updated_by").
Optional().
Nillable().
Comment("更新人用户ID管理员"),
field.Time("created_at").
Immutable().
Default(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("updated_at").
Default(time.Now).
UpdateDefault(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (Announcement) Edges() []ent.Edge {
return []ent.Edge{
edge.To("reads", AnnouncementRead.Type),
}
}
func (Announcement) Indexes() []ent.Index {
return []ent.Index{
index.Fields("status"),
index.Fields("created_at"),
index.Fields("starts_at"),
index.Fields("ends_at"),
}
}

View File

@@ -0,0 +1,65 @@
package schema
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// AnnouncementRead holds the schema definition for the AnnouncementRead entity.
//
// 记录用户对公告的已读状态(首次已读时间)。
type AnnouncementRead struct {
ent.Schema
}
func (AnnouncementRead) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "announcement_reads"},
}
}
func (AnnouncementRead) Fields() []ent.Field {
return []ent.Field{
field.Int64("announcement_id"),
field.Int64("user_id"),
field.Time("read_at").
Default(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}).
Comment("用户首次已读时间"),
field.Time("created_at").
Immutable().
Default(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (AnnouncementRead) Edges() []ent.Edge {
return []ent.Edge{
edge.From("announcement", Announcement.Type).
Ref("reads").
Field("announcement_id").
Unique().
Required(),
edge.From("user", User.Type).
Ref("announcement_reads").
Field("user_id").
Unique().
Required(),
}
}
func (AnnouncementRead) Indexes() []ent.Index {
return []ent.Index{
index.Fields("announcement_id"),
index.Fields("user_id"),
index.Fields("read_at"),
index.Fields("announcement_id", "user_id").Unique(),
}
}

View File

@@ -2,7 +2,7 @@ package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect/entsql"
@@ -45,7 +45,7 @@ func (APIKey) Fields() []ent.Field {
Nillable(),
field.String("status").
MaxLen(20).
Default(service.StatusActive),
Default(domain.StatusActive),
field.JSON("ip_whitelist", []string{}).
Optional().
Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"),

View File

@@ -2,7 +2,7 @@ package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -49,15 +49,15 @@ func (Group) Fields() []ent.Field {
Default(false),
field.String("status").
MaxLen(20).
Default(service.StatusActive),
Default(domain.StatusActive),
// Subscription-related fields (added by migration 003)
field.String("platform").
MaxLen(50).
Default(service.PlatformAnthropic),
Default(domain.PlatformAnthropic),
field.String("subscription_type").
MaxLen(20).
Default(service.SubscriptionTypeStandard),
Default(domain.SubscriptionTypeStandard),
field.Float("daily_limit_usd").
Optional().
Nillable().

View File

@@ -3,7 +3,7 @@ package schema
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -49,7 +49,7 @@ func (PromoCode) Fields() []ent.Field {
Comment("已使用次数"),
field.String("status").
MaxLen(20).
Default(service.PromoCodeStatusActive).
Default(domain.PromoCodeStatusActive).
Comment("状态: active, disabled"),
field.Time("expires_at").
Optional().

View File

@@ -3,7 +3,7 @@ package schema
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -41,13 +41,13 @@ func (RedeemCode) Fields() []ent.Field {
Unique(),
field.String("type").
MaxLen(20).
Default(service.RedeemTypeBalance),
Default(domain.RedeemTypeBalance),
field.Float("value").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0),
field.String("status").
MaxLen(20).
Default(service.StatusUnused),
Default(domain.StatusUnused),
field.Int64("used_by").
Optional().
Nillable(),

View File

@@ -2,7 +2,7 @@ package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -43,7 +43,7 @@ func (User) Fields() []ent.Field {
NotEmpty(),
field.String("role").
MaxLen(20).
Default(service.RoleUser),
Default(domain.RoleUser),
field.Float("balance").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0),
@@ -51,7 +51,7 @@ func (User) Fields() []ent.Field {
Default(5),
field.String("status").
MaxLen(20).
Default(service.StatusActive),
Default(domain.StatusActive),
// Optional profile fields (added later; default '' in DB migration)
field.String("username").
@@ -61,6 +61,17 @@ func (User) Fields() []ent.Field {
field.String("notes").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Default(""),
// TOTP 双因素认证字段
field.String("totp_secret_encrypted").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Optional().
Nillable(),
field.Bool("totp_enabled").
Default(false),
field.Time("totp_enabled_at").
Optional().
Nillable(),
}
}
@@ -70,6 +81,7 @@ func (User) Edges() []ent.Edge {
edge.To("redeem_codes", RedeemCode.Type),
edge.To("subscriptions", UserSubscription.Type),
edge.To("assigned_subscriptions", UserSubscription.Type),
edge.To("announcement_reads", AnnouncementRead.Type),
edge.To("allowed_groups", Group.Type).
Through("user_allowed_groups", UserAllowedGroup.Type),
edge.To("usage_logs", UsageLog.Type),

View File

@@ -4,7 +4,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -44,7 +44,7 @@ func (UserSubscription) Fields() []ent.Field {
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.String("status").
MaxLen(20).
Default(service.SubscriptionStatusActive),
Default(domain.SubscriptionStatusActive),
field.Time("daily_window_start").
Optional().

View File

@@ -20,6 +20,10 @@ type Tx struct {
Account *AccountClient
// AccountGroup is the client for interacting with the AccountGroup builders.
AccountGroup *AccountGroupClient
// Announcement is the client for interacting with the Announcement builders.
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// PromoCode is the client for interacting with the PromoCode builders.
@@ -180,6 +184,8 @@ func (tx *Tx) init() {
tx.APIKey = NewAPIKeyClient(tx.config)
tx.Account = NewAccountClient(tx.config)
tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Announcement = NewAnnouncementClient(tx.config)
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)

View File

@@ -39,6 +39,12 @@ type User struct {
Username string `json:"username,omitempty"`
// Notes holds the value of the "notes" field.
Notes string `json:"notes,omitempty"`
// TotpSecretEncrypted holds the value of the "totp_secret_encrypted" field.
TotpSecretEncrypted *string `json:"totp_secret_encrypted,omitempty"`
// TotpEnabled holds the value of the "totp_enabled" field.
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -55,6 +61,8 @@ type UserEdges struct {
Subscriptions []*UserSubscription `json:"subscriptions,omitempty"`
// AssignedSubscriptions holds the value of the assigned_subscriptions edge.
AssignedSubscriptions []*UserSubscription `json:"assigned_subscriptions,omitempty"`
// AnnouncementReads holds the value of the announcement_reads edge.
AnnouncementReads []*AnnouncementRead `json:"announcement_reads,omitempty"`
// AllowedGroups holds the value of the allowed_groups edge.
AllowedGroups []*Group `json:"allowed_groups,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
@@ -67,7 +75,7 @@ type UserEdges struct {
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [9]bool
loadedTypes [10]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@@ -106,10 +114,19 @@ func (e UserEdges) AssignedSubscriptionsOrErr() ([]*UserSubscription, error) {
return nil, &NotLoadedError{edge: "assigned_subscriptions"}
}
// AnnouncementReadsOrErr returns the AnnouncementReads value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AnnouncementReadsOrErr() ([]*AnnouncementRead, error) {
if e.loadedTypes[4] {
return e.AnnouncementReads, nil
}
return nil, &NotLoadedError{edge: "announcement_reads"}
}
// AllowedGroupsOrErr returns the AllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AllowedGroupsOrErr() ([]*Group, error) {
if e.loadedTypes[4] {
if e.loadedTypes[5] {
return e.AllowedGroups, nil
}
return nil, &NotLoadedError{edge: "allowed_groups"}
@@ -118,7 +135,7 @@ func (e UserEdges) AllowedGroupsOrErr() ([]*Group, error) {
// UsageLogsOrErr returns the UsageLogs value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) {
if e.loadedTypes[5] {
if e.loadedTypes[6] {
return e.UsageLogs, nil
}
return nil, &NotLoadedError{edge: "usage_logs"}
@@ -127,7 +144,7 @@ func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) {
// AttributeValuesOrErr returns the AttributeValues value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AttributeValuesOrErr() ([]*UserAttributeValue, error) {
if e.loadedTypes[6] {
if e.loadedTypes[7] {
return e.AttributeValues, nil
}
return nil, &NotLoadedError{edge: "attribute_values"}
@@ -136,7 +153,7 @@ func (e UserEdges) AttributeValuesOrErr() ([]*UserAttributeValue, error) {
// PromoCodeUsagesOrErr returns the PromoCodeUsages value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) PromoCodeUsagesOrErr() ([]*PromoCodeUsage, error) {
if e.loadedTypes[7] {
if e.loadedTypes[8] {
return e.PromoCodeUsages, nil
}
return nil, &NotLoadedError{edge: "promo_code_usages"}
@@ -145,7 +162,7 @@ func (e UserEdges) PromoCodeUsagesOrErr() ([]*PromoCodeUsage, error) {
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[8] {
if e.loadedTypes[9] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@@ -156,13 +173,15 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case user.FieldTotpEnabled:
values[i] = new(sql.NullBool)
case user.FieldBalance:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes:
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -252,6 +271,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Notes = value.String
}
case user.FieldTotpSecretEncrypted:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field totp_secret_encrypted", values[i])
} else if value.Valid {
_m.TotpSecretEncrypted = new(string)
*_m.TotpSecretEncrypted = value.String
}
case user.FieldTotpEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field totp_enabled", values[i])
} else if value.Valid {
_m.TotpEnabled = value.Bool
}
case user.FieldTotpEnabledAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field totp_enabled_at", values[i])
} else if value.Valid {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -285,6 +324,11 @@ func (_m *User) QueryAssignedSubscriptions() *UserSubscriptionQuery {
return NewUserClient(_m.config).QueryAssignedSubscriptions(_m)
}
// QueryAnnouncementReads queries the "announcement_reads" edge of the User entity.
func (_m *User) QueryAnnouncementReads() *AnnouncementReadQuery {
return NewUserClient(_m.config).QueryAnnouncementReads(_m)
}
// QueryAllowedGroups queries the "allowed_groups" edge of the User entity.
func (_m *User) QueryAllowedGroups() *GroupQuery {
return NewUserClient(_m.config).QueryAllowedGroups(_m)
@@ -367,6 +411,19 @@ func (_m *User) String() string {
builder.WriteString(", ")
builder.WriteString("notes=")
builder.WriteString(_m.Notes)
builder.WriteString(", ")
if v := _m.TotpSecretEncrypted; v != nil {
builder.WriteString("totp_secret_encrypted=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("totp_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.TotpEnabled))
builder.WriteString(", ")
if v := _m.TotpEnabledAt; v != nil {
builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteByte(')')
return builder.String()
}

View File

@@ -37,6 +37,12 @@ const (
FieldUsername = "username"
// FieldNotes holds the string denoting the notes field in the database.
FieldNotes = "notes"
// FieldTotpSecretEncrypted holds the string denoting the totp_secret_encrypted field in the database.
FieldTotpSecretEncrypted = "totp_secret_encrypted"
// FieldTotpEnabled holds the string denoting the totp_enabled field in the database.
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -45,6 +51,8 @@ const (
EdgeSubscriptions = "subscriptions"
// EdgeAssignedSubscriptions holds the string denoting the assigned_subscriptions edge name in mutations.
EdgeAssignedSubscriptions = "assigned_subscriptions"
// EdgeAnnouncementReads holds the string denoting the announcement_reads edge name in mutations.
EdgeAnnouncementReads = "announcement_reads"
// EdgeAllowedGroups holds the string denoting the allowed_groups edge name in mutations.
EdgeAllowedGroups = "allowed_groups"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
@@ -85,6 +93,13 @@ const (
AssignedSubscriptionsInverseTable = "user_subscriptions"
// AssignedSubscriptionsColumn is the table column denoting the assigned_subscriptions relation/edge.
AssignedSubscriptionsColumn = "assigned_by"
// AnnouncementReadsTable is the table that holds the announcement_reads relation/edge.
AnnouncementReadsTable = "announcement_reads"
// AnnouncementReadsInverseTable is the table name for the AnnouncementRead entity.
// It exists in this package in order to avoid circular dependency with the "announcementread" package.
AnnouncementReadsInverseTable = "announcement_reads"
// AnnouncementReadsColumn is the table column denoting the announcement_reads relation/edge.
AnnouncementReadsColumn = "user_id"
// AllowedGroupsTable is the table that holds the allowed_groups relation/edge. The primary key declared below.
AllowedGroupsTable = "user_allowed_groups"
// AllowedGroupsInverseTable is the table name for the Group entity.
@@ -134,6 +149,9 @@ var Columns = []string{
FieldStatus,
FieldUsername,
FieldNotes,
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
}
var (
@@ -188,6 +206,8 @@ var (
UsernameValidator func(string) error
// DefaultNotes holds the default value on creation for the "notes" field.
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
)
// OrderOption defines the ordering options for the User queries.
@@ -253,6 +273,21 @@ func ByNotes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldNotes, opts...).ToFunc()
}
// ByTotpSecretEncrypted orders the results by the totp_secret_encrypted field.
func ByTotpSecretEncrypted(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpSecretEncrypted, opts...).ToFunc()
}
// ByTotpEnabled orders the results by the totp_enabled field.
func ByTotpEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabled, opts...).ToFunc()
}
// ByTotpEnabledAt orders the results by the totp_enabled_at field.
func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -309,6 +344,20 @@ func ByAssignedSubscriptions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOp
}
}
// ByAnnouncementReadsCount orders the results by announcement_reads count.
func ByAnnouncementReadsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newAnnouncementReadsStep(), opts...)
}
}
// ByAnnouncementReads orders the results by announcement_reads terms.
func ByAnnouncementReads(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAnnouncementReadsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByAllowedGroupsCount orders the results by allowed_groups count.
func ByAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -406,6 +455,13 @@ func newAssignedSubscriptionsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, AssignedSubscriptionsTable, AssignedSubscriptionsColumn),
)
}
func newAnnouncementReadsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AnnouncementReadsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AnnouncementReadsTable, AnnouncementReadsColumn),
)
}
func newAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),

View File

@@ -110,6 +110,21 @@ func Notes(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldNotes, v))
}
// TotpSecretEncrypted applies equality check predicate on the "totp_secret_encrypted" field. It's identical to TotpSecretEncryptedEQ.
func TotpSecretEncrypted(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
}
// TotpEnabled applies equality check predicate on the "totp_enabled" field. It's identical to TotpEnabledEQ.
func TotpEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
}
// TotpEnabledAt applies equality check predicate on the "totp_enabled_at" field. It's identical to TotpEnabledAtEQ.
func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -710,6 +725,141 @@ func NotesContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldNotes, v))
}
// TotpSecretEncryptedEQ applies the EQ predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedNEQ applies the NEQ predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedIn applies the In predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldTotpSecretEncrypted, vs...))
}
// TotpSecretEncryptedNotIn applies the NotIn predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldTotpSecretEncrypted, vs...))
}
// TotpSecretEncryptedGT applies the GT predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedGTE applies the GTE predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedLT applies the LT predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedLTE applies the LTE predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedContains applies the Contains predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedHasPrefix applies the HasPrefix predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedHasSuffix applies the HasSuffix predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedIsNil applies the IsNil predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldTotpSecretEncrypted))
}
// TotpSecretEncryptedNotNil applies the NotNil predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpSecretEncrypted))
}
// TotpSecretEncryptedEqualFold applies the EqualFold predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedContainsFold applies the ContainsFold predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldTotpSecretEncrypted, v))
}
// TotpEnabledEQ applies the EQ predicate on the "totp_enabled" field.
func TotpEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
}
// TotpEnabledNEQ applies the NEQ predicate on the "totp_enabled" field.
func TotpEnabledNEQ(v bool) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotpEnabled, v))
}
// TotpEnabledAtEQ applies the EQ predicate on the "totp_enabled_at" field.
func TotpEnabledAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// TotpEnabledAtNEQ applies the NEQ predicate on the "totp_enabled_at" field.
func TotpEnabledAtNEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotpEnabledAt, v))
}
// TotpEnabledAtIn applies the In predicate on the "totp_enabled_at" field.
func TotpEnabledAtIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldIn(FieldTotpEnabledAt, vs...))
}
// TotpEnabledAtNotIn applies the NotIn predicate on the "totp_enabled_at" field.
func TotpEnabledAtNotIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldNotIn(FieldTotpEnabledAt, vs...))
}
// TotpEnabledAtGT applies the GT predicate on the "totp_enabled_at" field.
func TotpEnabledAtGT(v time.Time) predicate.User {
return predicate.User(sql.FieldGT(FieldTotpEnabledAt, v))
}
// TotpEnabledAtGTE applies the GTE predicate on the "totp_enabled_at" field.
func TotpEnabledAtGTE(v time.Time) predicate.User {
return predicate.User(sql.FieldGTE(FieldTotpEnabledAt, v))
}
// TotpEnabledAtLT applies the LT predicate on the "totp_enabled_at" field.
func TotpEnabledAtLT(v time.Time) predicate.User {
return predicate.User(sql.FieldLT(FieldTotpEnabledAt, v))
}
// TotpEnabledAtLTE applies the LTE predicate on the "totp_enabled_at" field.
func TotpEnabledAtLTE(v time.Time) predicate.User {
return predicate.User(sql.FieldLTE(FieldTotpEnabledAt, v))
}
// TotpEnabledAtIsNil applies the IsNil predicate on the "totp_enabled_at" field.
func TotpEnabledAtIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldTotpEnabledAt))
}
// TotpEnabledAtNotNil applies the NotNil predicate on the "totp_enabled_at" field.
func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
@@ -802,6 +952,29 @@ func HasAssignedSubscriptionsWith(preds ...predicate.UserSubscription) predicate
})
}
// HasAnnouncementReads applies the HasEdge predicate on the "announcement_reads" edge.
func HasAnnouncementReads() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AnnouncementReadsTable, AnnouncementReadsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasAnnouncementReadsWith applies the HasEdge predicate on the "announcement_reads" edge with a given conditions (other predicates).
func HasAnnouncementReadsWith(preds ...predicate.AnnouncementRead) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newAnnouncementReadsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasAllowedGroups applies the HasEdge predicate on the "allowed_groups" edge.
func HasAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@@ -11,6 +11,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -167,6 +168,48 @@ func (_c *UserCreate) SetNillableNotes(v *string) *UserCreate {
return _c
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (_c *UserCreate) SetTotpSecretEncrypted(v string) *UserCreate {
_c.mutation.SetTotpSecretEncrypted(v)
return _c
}
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotpSecretEncrypted(v *string) *UserCreate {
if v != nil {
_c.SetTotpSecretEncrypted(*v)
}
return _c
}
// SetTotpEnabled sets the "totp_enabled" field.
func (_c *UserCreate) SetTotpEnabled(v bool) *UserCreate {
_c.mutation.SetTotpEnabled(v)
return _c
}
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotpEnabled(v *bool) *UserCreate {
if v != nil {
_c.SetTotpEnabled(*v)
}
return _c
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (_c *UserCreate) SetTotpEnabledAt(v time.Time) *UserCreate {
_c.mutation.SetTotpEnabledAt(v)
return _c
}
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
if v != nil {
_c.SetTotpEnabledAt(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -227,6 +270,21 @@ func (_c *UserCreate) AddAssignedSubscriptions(v ...*UserSubscription) *UserCrea
return _c.AddAssignedSubscriptionIDs(ids...)
}
// AddAnnouncementReadIDs adds the "announcement_reads" edge to the AnnouncementRead entity by IDs.
func (_c *UserCreate) AddAnnouncementReadIDs(ids ...int64) *UserCreate {
_c.mutation.AddAnnouncementReadIDs(ids...)
return _c
}
// AddAnnouncementReads adds the "announcement_reads" edges to the AnnouncementRead entity.
func (_c *UserCreate) AddAnnouncementReads(v ...*AnnouncementRead) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddAnnouncementReadIDs(ids...)
}
// AddAllowedGroupIDs adds the "allowed_groups" edge to the Group entity by IDs.
func (_c *UserCreate) AddAllowedGroupIDs(ids ...int64) *UserCreate {
_c.mutation.AddAllowedGroupIDs(ids...)
@@ -362,6 +420,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultNotes
_c.mutation.SetNotes(v)
}
if _, ok := _c.mutation.TotpEnabled(); !ok {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
return nil
}
@@ -422,6 +484,9 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.Notes(); !ok {
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
}
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
return nil
}
@@ -493,6 +558,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldNotes, field.TypeString, value)
_node.Notes = value
}
if value, ok := _c.mutation.TotpSecretEncrypted(); ok {
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
_node.TotpSecretEncrypted = &value
}
if value, ok := _c.mutation.TotpEnabled(); ok {
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
_node.TotpEnabled = value
}
if value, ok := _c.mutation.TotpEnabledAt(); ok {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -557,6 +634,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.AnnouncementReadsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.AllowedGroupsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,
@@ -815,6 +908,54 @@ func (u *UserUpsert) UpdateNotes() *UserUpsert {
return u
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (u *UserUpsert) SetTotpSecretEncrypted(v string) *UserUpsert {
u.Set(user.FieldTotpSecretEncrypted, v)
return u
}
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotpSecretEncrypted() *UserUpsert {
u.SetExcluded(user.FieldTotpSecretEncrypted)
return u
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (u *UserUpsert) ClearTotpSecretEncrypted() *UserUpsert {
u.SetNull(user.FieldTotpSecretEncrypted)
return u
}
// SetTotpEnabled sets the "totp_enabled" field.
func (u *UserUpsert) SetTotpEnabled(v bool) *UserUpsert {
u.Set(user.FieldTotpEnabled, v)
return u
}
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotpEnabled() *UserUpsert {
u.SetExcluded(user.FieldTotpEnabled)
return u
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (u *UserUpsert) SetTotpEnabledAt(v time.Time) *UserUpsert {
u.Set(user.FieldTotpEnabledAt, v)
return u
}
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotpEnabledAt() *UserUpsert {
u.SetExcluded(user.FieldTotpEnabledAt)
return u
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
u.SetNull(user.FieldTotpEnabledAt)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1021,6 +1162,62 @@ func (u *UserUpsertOne) UpdateNotes() *UserUpsertOne {
})
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (u *UserUpsertOne) SetTotpSecretEncrypted(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotpSecretEncrypted(v)
})
}
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotpSecretEncrypted() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpSecretEncrypted()
})
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (u *UserUpsertOne) ClearTotpSecretEncrypted() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearTotpSecretEncrypted()
})
}
// SetTotpEnabled sets the "totp_enabled" field.
func (u *UserUpsertOne) SetTotpEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabled(v)
})
}
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotpEnabled() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabled()
})
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (u *UserUpsertOne) SetTotpEnabledAt(v time.Time) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabledAt(v)
})
}
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotpEnabledAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabledAt()
})
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearTotpEnabledAt()
})
}
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1393,6 +1590,62 @@ func (u *UserUpsertBulk) UpdateNotes() *UserUpsertBulk {
})
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (u *UserUpsertBulk) SetTotpSecretEncrypted(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotpSecretEncrypted(v)
})
}
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotpSecretEncrypted() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpSecretEncrypted()
})
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (u *UserUpsertBulk) ClearTotpSecretEncrypted() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearTotpSecretEncrypted()
})
}
// SetTotpEnabled sets the "totp_enabled" field.
func (u *UserUpsertBulk) SetTotpEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabled(v)
})
}
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotpEnabled() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabled()
})
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (u *UserUpsertBulk) SetTotpEnabledAt(v time.Time) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabledAt(v)
})
}
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotpEnabledAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabledAt()
})
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearTotpEnabledAt()
})
}
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -13,6 +13,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
@@ -36,6 +37,7 @@ type UserQuery struct {
withRedeemCodes *RedeemCodeQuery
withSubscriptions *UserSubscriptionQuery
withAssignedSubscriptions *UserSubscriptionQuery
withAnnouncementReads *AnnouncementReadQuery
withAllowedGroups *GroupQuery
withUsageLogs *UsageLogQuery
withAttributeValues *UserAttributeValueQuery
@@ -166,6 +168,28 @@ func (_q *UserQuery) QueryAssignedSubscriptions() *UserSubscriptionQuery {
return query
}
// QueryAnnouncementReads chains the current query on the "announcement_reads" edge.
func (_q *UserQuery) QueryAnnouncementReads() *AnnouncementReadQuery {
query := (&AnnouncementReadClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(announcementread.Table, announcementread.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.AnnouncementReadsTable, user.AnnouncementReadsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryAllowedGroups chains the current query on the "allowed_groups" edge.
func (_q *UserQuery) QueryAllowedGroups() *GroupQuery {
query := (&GroupClient{config: _q.config}).Query()
@@ -472,6 +496,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withRedeemCodes: _q.withRedeemCodes.Clone(),
withSubscriptions: _q.withSubscriptions.Clone(),
withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(),
withAnnouncementReads: _q.withAnnouncementReads.Clone(),
withAllowedGroups: _q.withAllowedGroups.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
withAttributeValues: _q.withAttributeValues.Clone(),
@@ -527,6 +552,17 @@ func (_q *UserQuery) WithAssignedSubscriptions(opts ...func(*UserSubscriptionQue
return _q
}
// WithAnnouncementReads tells the query-builder to eager-load the nodes that are connected to
// the "announcement_reads" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithAnnouncementReads(opts ...func(*AnnouncementReadQuery)) *UserQuery {
query := (&AnnouncementReadClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAnnouncementReads = query
return _q
}
// WithAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithAllowedGroups(opts ...func(*GroupQuery)) *UserQuery {
@@ -660,11 +696,12 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
loadedTypes = [9]bool{
loadedTypes = [10]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
_q.withAssignedSubscriptions != nil,
_q.withAnnouncementReads != nil,
_q.withAllowedGroups != nil,
_q.withUsageLogs != nil,
_q.withAttributeValues != nil,
@@ -723,6 +760,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
if query := _q.withAnnouncementReads; query != nil {
if err := _q.loadAnnouncementReads(ctx, query, nodes,
func(n *User) { n.Edges.AnnouncementReads = []*AnnouncementRead{} },
func(n *User, e *AnnouncementRead) { n.Edges.AnnouncementReads = append(n.Edges.AnnouncementReads, e) }); err != nil {
return nil, err
}
}
if query := _q.withAllowedGroups; query != nil {
if err := _q.loadAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.AllowedGroups = []*Group{} },
@@ -887,6 +931,36 @@ func (_q *UserQuery) loadAssignedSubscriptions(ctx context.Context, query *UserS
}
return nil
}
func (_q *UserQuery) loadAnnouncementReads(ctx context.Context, query *AnnouncementReadQuery, nodes []*User, init func(*User), assign func(*User, *AnnouncementRead)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(announcementread.FieldUserID)
}
query.Where(predicate.AnnouncementRead(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.AnnouncementReadsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadAllowedGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error {
edgeIDs := make([]driver.Value, len(nodes))
byID := make(map[int64]*User)

View File

@@ -11,6 +11,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
@@ -187,6 +188,60 @@ func (_u *UserUpdate) SetNillableNotes(v *string) *UserUpdate {
return _u
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (_u *UserUpdate) SetTotpSecretEncrypted(v string) *UserUpdate {
_u.mutation.SetTotpSecretEncrypted(v)
return _u
}
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotpSecretEncrypted(v *string) *UserUpdate {
if v != nil {
_u.SetTotpSecretEncrypted(*v)
}
return _u
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (_u *UserUpdate) ClearTotpSecretEncrypted() *UserUpdate {
_u.mutation.ClearTotpSecretEncrypted()
return _u
}
// SetTotpEnabled sets the "totp_enabled" field.
func (_u *UserUpdate) SetTotpEnabled(v bool) *UserUpdate {
_u.mutation.SetTotpEnabled(v)
return _u
}
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotpEnabled(v *bool) *UserUpdate {
if v != nil {
_u.SetTotpEnabled(*v)
}
return _u
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (_u *UserUpdate) SetTotpEnabledAt(v time.Time) *UserUpdate {
_u.mutation.SetTotpEnabledAt(v)
return _u
}
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotpEnabledAt(v *time.Time) *UserUpdate {
if v != nil {
_u.SetTotpEnabledAt(*v)
}
return _u
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
_u.mutation.ClearTotpEnabledAt()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -247,6 +302,21 @@ func (_u *UserUpdate) AddAssignedSubscriptions(v ...*UserSubscription) *UserUpda
return _u.AddAssignedSubscriptionIDs(ids...)
}
// AddAnnouncementReadIDs adds the "announcement_reads" edge to the AnnouncementRead entity by IDs.
func (_u *UserUpdate) AddAnnouncementReadIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAnnouncementReadIDs(ids...)
return _u
}
// AddAnnouncementReads adds the "announcement_reads" edges to the AnnouncementRead entity.
func (_u *UserUpdate) AddAnnouncementReads(v ...*AnnouncementRead) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAnnouncementReadIDs(ids...)
}
// AddAllowedGroupIDs adds the "allowed_groups" edge to the Group entity by IDs.
func (_u *UserUpdate) AddAllowedGroupIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAllowedGroupIDs(ids...)
@@ -396,6 +466,27 @@ func (_u *UserUpdate) RemoveAssignedSubscriptions(v ...*UserSubscription) *UserU
return _u.RemoveAssignedSubscriptionIDs(ids...)
}
// ClearAnnouncementReads clears all "announcement_reads" edges to the AnnouncementRead entity.
func (_u *UserUpdate) ClearAnnouncementReads() *UserUpdate {
_u.mutation.ClearAnnouncementReads()
return _u
}
// RemoveAnnouncementReadIDs removes the "announcement_reads" edge to AnnouncementRead entities by IDs.
func (_u *UserUpdate) RemoveAnnouncementReadIDs(ids ...int64) *UserUpdate {
_u.mutation.RemoveAnnouncementReadIDs(ids...)
return _u
}
// RemoveAnnouncementReads removes "announcement_reads" edges to AnnouncementRead entities.
func (_u *UserUpdate) RemoveAnnouncementReads(v ...*AnnouncementRead) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAnnouncementReadIDs(ids...)
}
// ClearAllowedGroups clears all "allowed_groups" edges to the Group entity.
func (_u *UserUpdate) ClearAllowedGroups() *UserUpdate {
_u.mutation.ClearAllowedGroups()
@@ -603,6 +694,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
}
if _u.mutation.TotpSecretEncryptedCleared() {
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
}
if value, ok := _u.mutation.TotpEnabled(); ok {
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.TotpEnabledAt(); ok {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
}
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -783,6 +889,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AnnouncementReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAnnouncementReadsIDs(); len(nodes) > 0 && !_u.mutation.AnnouncementReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AnnouncementReadsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AllowedGroupsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,
@@ -1147,6 +1298,60 @@ func (_u *UserUpdateOne) SetNillableNotes(v *string) *UserUpdateOne {
return _u
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (_u *UserUpdateOne) SetTotpSecretEncrypted(v string) *UserUpdateOne {
_u.mutation.SetTotpSecretEncrypted(v)
return _u
}
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotpSecretEncrypted(v *string) *UserUpdateOne {
if v != nil {
_u.SetTotpSecretEncrypted(*v)
}
return _u
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (_u *UserUpdateOne) ClearTotpSecretEncrypted() *UserUpdateOne {
_u.mutation.ClearTotpSecretEncrypted()
return _u
}
// SetTotpEnabled sets the "totp_enabled" field.
func (_u *UserUpdateOne) SetTotpEnabled(v bool) *UserUpdateOne {
_u.mutation.SetTotpEnabled(v)
return _u
}
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotpEnabled(v *bool) *UserUpdateOne {
if v != nil {
_u.SetTotpEnabled(*v)
}
return _u
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (_u *UserUpdateOne) SetTotpEnabledAt(v time.Time) *UserUpdateOne {
_u.mutation.SetTotpEnabledAt(v)
return _u
}
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotpEnabledAt(v *time.Time) *UserUpdateOne {
if v != nil {
_u.SetTotpEnabledAt(*v)
}
return _u
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
_u.mutation.ClearTotpEnabledAt()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1207,6 +1412,21 @@ func (_u *UserUpdateOne) AddAssignedSubscriptions(v ...*UserSubscription) *UserU
return _u.AddAssignedSubscriptionIDs(ids...)
}
// AddAnnouncementReadIDs adds the "announcement_reads" edge to the AnnouncementRead entity by IDs.
func (_u *UserUpdateOne) AddAnnouncementReadIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAnnouncementReadIDs(ids...)
return _u
}
// AddAnnouncementReads adds the "announcement_reads" edges to the AnnouncementRead entity.
func (_u *UserUpdateOne) AddAnnouncementReads(v ...*AnnouncementRead) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAnnouncementReadIDs(ids...)
}
// AddAllowedGroupIDs adds the "allowed_groups" edge to the Group entity by IDs.
func (_u *UserUpdateOne) AddAllowedGroupIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAllowedGroupIDs(ids...)
@@ -1356,6 +1576,27 @@ func (_u *UserUpdateOne) RemoveAssignedSubscriptions(v ...*UserSubscription) *Us
return _u.RemoveAssignedSubscriptionIDs(ids...)
}
// ClearAnnouncementReads clears all "announcement_reads" edges to the AnnouncementRead entity.
func (_u *UserUpdateOne) ClearAnnouncementReads() *UserUpdateOne {
_u.mutation.ClearAnnouncementReads()
return _u
}
// RemoveAnnouncementReadIDs removes the "announcement_reads" edge to AnnouncementRead entities by IDs.
func (_u *UserUpdateOne) RemoveAnnouncementReadIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemoveAnnouncementReadIDs(ids...)
return _u
}
// RemoveAnnouncementReads removes "announcement_reads" edges to AnnouncementRead entities.
func (_u *UserUpdateOne) RemoveAnnouncementReads(v ...*AnnouncementRead) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAnnouncementReadIDs(ids...)
}
// ClearAllowedGroups clears all "allowed_groups" edges to the Group entity.
func (_u *UserUpdateOne) ClearAllowedGroups() *UserUpdateOne {
_u.mutation.ClearAllowedGroups()
@@ -1593,6 +1834,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
}
if _u.mutation.TotpSecretEncryptedCleared() {
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
}
if value, ok := _u.mutation.TotpEnabled(); ok {
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.TotpEnabledAt(); ok {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
}
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1773,6 +2029,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AnnouncementReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAnnouncementReadsIDs(); len(nodes) > 0 && !_u.mutation.AnnouncementReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AnnouncementReadsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AllowedGroupsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,

View File

@@ -1,6 +1,6 @@
module github.com/Wei-Shaw/sub2api
go 1.25.5
go 1.25.6
require (
entgo.io/ent v0.14.5
@@ -37,6 +37,7 @@ require (
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -106,6 +107,7 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/pquerna/otp v1.5.0 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.57.1 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect

View File

@@ -20,6 +20,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -217,6 +219,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=

View File

@@ -47,6 +47,7 @@ type Config struct {
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
@@ -414,6 +415,8 @@ type RedisConfig struct {
PoolSize int `mapstructure:"pool_size"`
// MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟
MinIdleConns int `mapstructure:"min_idle_conns"`
// EnableTLS: 是否启用 TLS/SSL 连接
EnableTLS bool `mapstructure:"enable_tls"`
}
func (r *RedisConfig) Address() string {
@@ -466,6 +469,16 @@ type JWTConfig struct {
ExpireHour int `mapstructure:"expire_hour"`
}
// TotpConfig TOTP 双因素认证配置
type TotpConfig struct {
// EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥32 字节 hex 编码)
// 如果为空,将自动生成一个随机密钥(仅适用于开发环境)
EncryptionKey string `mapstructure:"encryption_key"`
// EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成)
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
EncryptionKeyConfigured bool `mapstructure:"-"`
}
type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
@@ -626,6 +639,20 @@ func Load() (*Config, error) {
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
if cfg.Totp.EncryptionKey == "" {
key, err := generateJWTSecret(32) // Reuse the same random generation function
if err != nil {
return nil, fmt.Errorf("generate totp encryption key error: %w", err)
}
cfg.Totp.EncryptionKey = key
cfg.Totp.EncryptionKeyConfigured = false
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
} else {
cfg.Totp.EncryptionKeyConfigured = true
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
@@ -737,6 +764,7 @@ func setDefaults() {
viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
viper.SetDefault("redis.enable_tls", false)
// Ops (vNext)
viper.SetDefault("ops.enabled", true)
@@ -756,6 +784,9 @@ func setDefaults() {
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
// TOTP
viper.SetDefault("totp.encryption_key", "")
// Default
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.

View File

@@ -0,0 +1,226 @@
package domain
import (
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
AnnouncementStatusDraft = "draft"
AnnouncementStatusActive = "active"
AnnouncementStatusArchived = "archived"
)
const (
AnnouncementConditionTypeSubscription = "subscription"
AnnouncementConditionTypeBalance = "balance"
)
const (
AnnouncementOperatorIn = "in"
AnnouncementOperatorGT = "gt"
AnnouncementOperatorGTE = "gte"
AnnouncementOperatorLT = "lt"
AnnouncementOperatorLTE = "lte"
AnnouncementOperatorEQ = "eq"
)
var (
ErrAnnouncementNotFound = infraerrors.NotFound("ANNOUNCEMENT_NOT_FOUND", "announcement not found")
ErrAnnouncementInvalidTarget = infraerrors.BadRequest("ANNOUNCEMENT_INVALID_TARGET", "invalid announcement targeting rules")
)
type AnnouncementTargeting struct {
// AnyOf 表示 OR任意一个条件组满足即可展示。
AnyOf []AnnouncementConditionGroup `json:"any_of,omitempty"`
}
type AnnouncementConditionGroup struct {
// AllOf 表示 AND组内所有条件都满足才算命中该组。
AllOf []AnnouncementCondition `json:"all_of,omitempty"`
}
type AnnouncementCondition struct {
// Type: subscription | balance
Type string `json:"type"`
// Operator:
// - subscription: in
// - balance: gt/gte/lt/lte/eq
Operator string `json:"operator"`
// subscription 条件匹配的订阅套餐group_id
GroupIDs []int64 `json:"group_ids,omitempty"`
// balance 条件:比较阈值
Value float64 `json:"value,omitempty"`
}
func (t AnnouncementTargeting) Matches(balance float64, activeSubscriptionGroupIDs map[int64]struct{}) bool {
// 空规则:展示给所有用户
if len(t.AnyOf) == 0 {
return true
}
for _, group := range t.AnyOf {
if len(group.AllOf) == 0 {
// 空条件组不命中(避免 OR 中出现无条件 “全命中”)
continue
}
allMatched := true
for _, cond := range group.AllOf {
if !cond.Matches(balance, activeSubscriptionGroupIDs) {
allMatched = false
break
}
}
if allMatched {
return true
}
}
return false
}
func (c AnnouncementCondition) Matches(balance float64, activeSubscriptionGroupIDs map[int64]struct{}) bool {
switch c.Type {
case AnnouncementConditionTypeSubscription:
if c.Operator != AnnouncementOperatorIn {
return false
}
if len(c.GroupIDs) == 0 {
return false
}
if len(activeSubscriptionGroupIDs) == 0 {
return false
}
for _, gid := range c.GroupIDs {
if _, ok := activeSubscriptionGroupIDs[gid]; ok {
return true
}
}
return false
case AnnouncementConditionTypeBalance:
switch c.Operator {
case AnnouncementOperatorGT:
return balance > c.Value
case AnnouncementOperatorGTE:
return balance >= c.Value
case AnnouncementOperatorLT:
return balance < c.Value
case AnnouncementOperatorLTE:
return balance <= c.Value
case AnnouncementOperatorEQ:
return balance == c.Value
default:
return false
}
default:
return false
}
}
func (t AnnouncementTargeting) NormalizeAndValidate() (AnnouncementTargeting, error) {
normalized := AnnouncementTargeting{AnyOf: make([]AnnouncementConditionGroup, 0, len(t.AnyOf))}
// 允许空 targeting展示给所有用户
if len(t.AnyOf) == 0 {
return normalized, nil
}
if len(t.AnyOf) > 50 {
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
}
for _, g := range t.AnyOf {
if len(g.AllOf) == 0 {
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
}
if len(g.AllOf) > 50 {
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
}
group := AnnouncementConditionGroup{AllOf: make([]AnnouncementCondition, 0, len(g.AllOf))}
for _, c := range g.AllOf {
cond := AnnouncementCondition{
Type: strings.TrimSpace(c.Type),
Operator: strings.TrimSpace(c.Operator),
Value: c.Value,
}
for _, gid := range c.GroupIDs {
if gid <= 0 {
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
}
cond.GroupIDs = append(cond.GroupIDs, gid)
}
if err := cond.validate(); err != nil {
return AnnouncementTargeting{}, err
}
group.AllOf = append(group.AllOf, cond)
}
normalized.AnyOf = append(normalized.AnyOf, group)
}
return normalized, nil
}
func (c AnnouncementCondition) validate() error {
switch c.Type {
case AnnouncementConditionTypeSubscription:
if c.Operator != AnnouncementOperatorIn {
return ErrAnnouncementInvalidTarget
}
if len(c.GroupIDs) == 0 {
return ErrAnnouncementInvalidTarget
}
return nil
case AnnouncementConditionTypeBalance:
switch c.Operator {
case AnnouncementOperatorGT, AnnouncementOperatorGTE, AnnouncementOperatorLT, AnnouncementOperatorLTE, AnnouncementOperatorEQ:
return nil
default:
return ErrAnnouncementInvalidTarget
}
default:
return ErrAnnouncementInvalidTarget
}
}
type Announcement struct {
ID int64
Title string
Content string
Status string
Targeting AnnouncementTargeting
StartsAt *time.Time
EndsAt *time.Time
CreatedBy *int64
UpdatedBy *int64
CreatedAt time.Time
UpdatedAt time.Time
}
func (a *Announcement) IsActiveAt(now time.Time) bool {
if a == nil {
return false
}
if a.Status != AnnouncementStatusActive {
return false
}
if a.StartsAt != nil && now.Before(*a.StartsAt) {
return false
}
if a.EndsAt != nil && !now.Before(*a.EndsAt) {
// ends_at 语义:到点即下线
return false
}
return true
}

View File

@@ -0,0 +1,64 @@
package domain
// Status constants
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// Role constants
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// Platform constants
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
)
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
)
// Redeem type constants
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
)
// PromoCode status constants
const (
PromoCodeStatusActive = "active"
PromoCodeStatusDisabled = "disabled"
)
// Admin adjustment type constants
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)
// Group subscription type constants
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
// Subscription status constants
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)

View File

@@ -45,6 +45,7 @@ type AccountHandler struct {
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
sessionLimitCache service.SessionLimitCache
tokenCacheInvalidator service.TokenCacheInvalidator
}
// NewAccountHandler creates a new admin account handler
@@ -60,6 +61,7 @@ func NewAccountHandler(
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
sessionLimitCache service.SessionLimitCache,
tokenCacheInvalidator service.TokenCacheInvalidator,
) *AccountHandler {
return &AccountHandler{
adminService: adminService,
@@ -73,6 +75,7 @@ func NewAccountHandler(
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
sessionLimitCache: sessionLimitCache,
tokenCacheInvalidator: tokenCacheInvalidator,
}
}
@@ -173,6 +176,7 @@ func (h *AccountHandler) List(c *gin.Context) {
// 识别需要查询窗口费用和会话数的账号Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
if acc.IsAnthropicOAuthOrSetupToken() {
@@ -181,6 +185,7 @@ func (h *AccountHandler) List(c *gin.Context) {
}
if acc.GetMaxSessions() > 0 {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
}
}
@@ -189,9 +194,9 @@ func (h *AccountHandler) List(c *gin.Context) {
var windowCosts map[int64]float64
var activeSessions map[int64]int
// 获取活跃会话数(批量查询)
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
if activeSessions == nil {
activeSessions = make(map[int64]int)
}
@@ -542,9 +547,18 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败project_id 也不会丢失
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
newCredentials["project_id"] = oldProjectID
}
}
// 如果 project_id 获取失败,更新凭证但不标记为 error
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
if tokenInfo.ProjectIDMissing {
// 先更新凭证
// 先更新凭证token 本身刷新成功了)
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
@@ -552,14 +566,10 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
return
}
// 标记账户为 error
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id可能无法使用Antigravity"); setErr != nil {
response.InternalError(c, "Failed to set account error: "+setErr.Error())
return
}
// 标记为 error,只返回警告信息
response.Success(c, gin.H{
"message": "Token refreshed but project_id is missing, account marked as error",
"warning": "missing_project_id",
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
"warning": "missing_project_id_temporary",
})
return
}
@@ -606,6 +616,14 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
return
}
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
if h.tokenCacheInvalidator != nil {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
// 缓存失效失败只记录日志,不影响主流程
_ = c.Error(invalidateErr)
}
}
response.Success(c, dto.AccountFromService(updatedAccount))
}
@@ -655,6 +673,15 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
return
}
// 清除错误后,同时清除 token 缓存,确保下次请求会获取最新的 token触发刷新或从 DB 读取)
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
// 缓存失效失败只记录日志,不影响主流程
_ = c.Error(invalidateErr)
}
}
response.Success(c, dto.AccountFromService(account))
}

View File

@@ -0,0 +1,246 @@
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AnnouncementHandler handles admin announcement management
type AnnouncementHandler struct {
announcementService *service.AnnouncementService
}
// NewAnnouncementHandler creates a new admin announcement handler
func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
return &AnnouncementHandler{
announcementService: announcementService,
}
}
type CreateAnnouncementRequest struct {
Title string `json:"title" binding:"required"`
Content string `json:"content" binding:"required"`
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
Targeting service.AnnouncementTargeting `json:"targeting"`
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
}
type UpdateAnnouncementRequest struct {
Title *string `json:"title"`
Content *string `json:"content"`
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
Targeting *service.AnnouncementTargeting `json:"targeting"`
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
}
// List handles listing announcements with filters
// GET /api/v1/admin/announcements
func (h *AnnouncementHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := strings.TrimSpace(c.Query("status"))
search := strings.TrimSpace(c.Query("search"))
if len(search) > 200 {
search = search[:200]
}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
items, paginationResult, err := h.announcementService.List(
c.Request.Context(),
params,
service.AnnouncementListFilters{Status: status, Search: search},
)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.Announcement, 0, len(items))
for i := range items {
out = append(out, *dto.AnnouncementFromService(&items[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
// GetByID handles getting an announcement by ID
// GET /api/v1/admin/announcements/:id
func (h *AnnouncementHandler) GetByID(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
item, err := h.announcementService.GetByID(c.Request.Context(), announcementID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AnnouncementFromService(item))
}
// Create handles creating a new announcement
// POST /api/v1/admin/announcements
func (h *AnnouncementHandler) Create(c *gin.Context) {
var req CreateAnnouncementRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
input := &service.CreateAnnouncementInput{
Title: req.Title,
Content: req.Content,
Status: req.Status,
Targeting: req.Targeting,
ActorID: &subject.UserID,
}
if req.StartsAt != nil && *req.StartsAt > 0 {
t := time.Unix(*req.StartsAt, 0)
input.StartsAt = &t
}
if req.EndsAt != nil && *req.EndsAt > 0 {
t := time.Unix(*req.EndsAt, 0)
input.EndsAt = &t
}
created, err := h.announcementService.Create(c.Request.Context(), input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AnnouncementFromService(created))
}
// Update handles updating an announcement
// PUT /api/v1/admin/announcements/:id
func (h *AnnouncementHandler) Update(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
var req UpdateAnnouncementRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
input := &service.UpdateAnnouncementInput{
Title: req.Title,
Content: req.Content,
Status: req.Status,
Targeting: req.Targeting,
ActorID: &subject.UserID,
}
if req.StartsAt != nil {
if *req.StartsAt == 0 {
var cleared *time.Time = nil
input.StartsAt = &cleared
} else {
t := time.Unix(*req.StartsAt, 0)
ptr := &t
input.StartsAt = &ptr
}
}
if req.EndsAt != nil {
if *req.EndsAt == 0 {
var cleared *time.Time = nil
input.EndsAt = &cleared
} else {
t := time.Unix(*req.EndsAt, 0)
ptr := &t
input.EndsAt = &ptr
}
}
updated, err := h.announcementService.Update(c.Request.Context(), announcementID, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AnnouncementFromService(updated))
}
// Delete handles deleting an announcement
// DELETE /api/v1/admin/announcements/:id
func (h *AnnouncementHandler) Delete(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
if err := h.announcementService.Delete(c.Request.Context(), announcementID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Announcement deleted successfully"})
}
// ListReadStatus handles listing users read status for an announcement
// GET /api/v1/admin/announcements/:id/read-status
func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
search := strings.TrimSpace(c.Query("search"))
if len(search) > 200 {
search = search[:200]
}
items, paginationResult, err := h.announcementService.ListUserReadStatus(
c.Request.Context(),
announcementID,
params,
search,
)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, items, paginationResult.Total, page, pageSize)
}

View File

@@ -94,9 +94,9 @@ func (h *GroupHandler) List(c *gin.Context) {
return
}
outGroups := make([]dto.Group, 0, len(groups))
outGroups := make([]dto.AdminGroup, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
}
response.Paginated(c, outGroups, total, page, pageSize)
}
@@ -120,9 +120,9 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
return
}
outGroups := make([]dto.Group, 0, len(groups))
outGroups := make([]dto.AdminGroup, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
}
response.Success(c, outGroups)
}
@@ -142,7 +142,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.GroupFromService(group))
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Create handles creating a new group
@@ -177,7 +177,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
return
}
response.Success(c, dto.GroupFromService(group))
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Update handles updating a group
@@ -219,7 +219,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.GroupFromService(group))
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Delete handles deleting a group

View File

@@ -54,9 +54,9 @@ func (h *RedeemHandler) List(c *gin.Context) {
return
}
out := make([]dto.RedeemCode, 0, len(codes))
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -76,7 +76,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.RedeemCodeFromService(code))
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
}
// Generate handles generating new redeem codes
@@ -100,9 +100,9 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return
}
out := make([]dto.RedeemCode, 0, len(codes))
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
response.Success(c, out)
}
@@ -163,7 +163,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
return
}
response.Success(c, dto.RedeemCodeFromService(code))
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
}
// GetStats handles getting redeem code statistics

View File

@@ -47,6 +47,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
@@ -68,6 +72,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
@@ -87,8 +94,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
@@ -111,13 +121,16 @@ type UpdateSettingsRequest struct {
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
@@ -194,6 +207,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// TOTP 双因素认证参数验证
// 只有手动配置了加密密钥才允许启用 TOTP 功能
if req.TotpEnabled && !previousSettings.TotpEnabled {
// 尝试启用 TOTP检查加密密钥是否已手动配置
if !h.settingService.IsTotpEncryptionKeyConfigured() {
response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
return
}
}
// LinuxDo Connect 参数验证
if req.LinuxDoConnectEnabled {
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
@@ -223,6 +246,34 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// “购买订阅”页面配置验证
purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled
if req.PurchaseSubscriptionEnabled != nil {
purchaseEnabled = *req.PurchaseSubscriptionEnabled
}
purchaseURL := previousSettings.PurchaseSubscriptionURL
if req.PurchaseSubscriptionURL != nil {
purchaseURL = strings.TrimSpace(*req.PurchaseSubscriptionURL)
}
// - 启用时要求 URL 合法且非空
// - 禁用时允许为空;若提供了 URL 也做基本校验,避免误配置
if purchaseEnabled {
if purchaseURL == "" {
response.BadRequest(c, "Purchase Subscription URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
return
}
} else if purchaseURL != "" {
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
return
}
}
// Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds
@@ -236,38 +287,44 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -311,6 +368,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
@@ -332,6 +393,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
HomeContent: updatedSettings.HomeContent,
HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
@@ -376,6 +440,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
if before.TotpEnabled != after.TotpEnabled {
changed = append(changed, "totp_enabled")
}
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}
@@ -439,6 +509,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.HomeContent != after.HomeContent {
changed = append(changed, "home_content")
}
if before.HideCcsImportButton != after.HideCcsImportButton {
changed = append(changed, "hide_ccs_import_button")
}
if before.DefaultConcurrency != after.DefaultConcurrency {
changed = append(changed, "default_concurrency")
}

View File

@@ -53,9 +53,9 @@ type BulkAssignSubscriptionRequest struct {
Notes string `json:"notes"`
}
// ExtendSubscriptionRequest represents extend subscription request
type ExtendSubscriptionRequest struct {
Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years
// AdjustSubscriptionRequest represents adjust subscription request (extend or shorten)
type AdjustSubscriptionRequest struct {
Days int `json:"days" binding:"required,min=-36500,max=36500"` // negative to shorten, positive to extend
}
// List handles listing all subscriptions with pagination and filters
@@ -77,15 +77,19 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
}
status := c.Query("status")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
// Parse sorting parameters
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.UserSubscription, 0, len(subscriptions))
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
@@ -105,7 +109,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.UserSubscriptionFromService(subscription))
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// GetProgress handles getting subscription usage progress
@@ -150,7 +154,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
return
}
response.Success(c, dto.UserSubscriptionFromService(subscription))
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// BulkAssign handles bulk assigning subscriptions to multiple users
@@ -180,7 +184,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
response.Success(c, dto.BulkAssignResultFromService(result))
}
// Extend handles extending a subscription
// Extend handles adjusting a subscription (extend or shorten)
// POST /api/v1/admin/subscriptions/:id/extend
func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
@@ -189,7 +193,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return
}
var req ExtendSubscriptionRequest
var req AdjustSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
@@ -201,7 +205,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return
}
response.Success(c, dto.UserSubscriptionFromService(subscription))
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// Revoke handles revoking a subscription
@@ -239,9 +243,9 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
return
}
out := make([]dto.UserSubscription, 0, len(subscriptions))
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
@@ -261,9 +265,9 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
return
}
out := make([]dto.UserSubscription, 0, len(subscriptions))
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.Success(c, out)
}

View File

@@ -163,7 +163,7 @@ func (h *UsageHandler) List(c *gin.Context) {
return
}
out := make([]dto.UsageLog, 0, len(records))
out := make([]dto.AdminUsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
}

View File

@@ -84,9 +84,9 @@ func (h *UserHandler) List(c *gin.Context) {
return
}
out := make([]dto.User, 0, len(users))
out := make([]dto.AdminUser, 0, len(users))
for i := range users {
out = append(out, *dto.UserFromService(&users[i]))
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -129,7 +129,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Create handles creating a new user
@@ -155,7 +155,7 @@ func (h *UserHandler) Create(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Update handles updating a user
@@ -189,7 +189,7 @@ func (h *UserHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Delete handles deleting a user
@@ -231,7 +231,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// GetUserAPIKeys handles getting user's API keys

View File

@@ -0,0 +1,81 @@
package handler
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AnnouncementHandler handles user announcement operations
type AnnouncementHandler struct {
announcementService *service.AnnouncementService
}
// NewAnnouncementHandler creates a new user announcement handler
func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
return &AnnouncementHandler{
announcementService: announcementService,
}
}
// List handles listing announcements visible to current user
// GET /api/v1/announcements
func (h *AnnouncementHandler) List(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
unreadOnly := parseBoolQuery(c.Query("unread_only"))
items, err := h.announcementService.ListForUser(c.Request.Context(), subject.UserID, unreadOnly)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.UserAnnouncement, 0, len(items))
for i := range items {
out = append(out, *dto.UserAnnouncementFromService(&items[i]))
}
response.Success(c, out)
}
// MarkRead marks an announcement as read for current user
// POST /api/v1/announcements/:id/read
func (h *AnnouncementHandler) MarkRead(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
if err := h.announcementService.MarkRead(c.Request.Context(), subject.UserID, announcementID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "ok"})
}
func parseBoolQuery(v string) bool {
switch strings.TrimSpace(strings.ToLower(v)) {
case "1", "true", "yes", "y", "on":
return true
default:
return false
}
}

View File

@@ -1,6 +1,8 @@
package handler
import (
"log/slog"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
@@ -18,16 +20,18 @@ type AuthHandler struct {
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
totpService *service.TotpService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler {
return &AuthHandler{
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
totpService: totpService,
}
}
@@ -144,6 +148,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA
tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
if err != nil {
response.InternalError(c, "Failed to create 2FA session")
return
}
response.Success(c, TotpLoginResponse{
Requires2FA: true,
TempToken: tempToken,
UserEmailMasked: service.MaskEmail(user.Email),
})
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
}
// TotpLoginResponse represents the response when 2FA is required
type TotpLoginResponse struct {
Requires2FA bool `json:"requires_2fa"`
TempToken string `json:"temp_token,omitempty"`
UserEmailMasked string `json:"user_email_masked,omitempty"`
}
// Login2FARequest represents the 2FA login request
type Login2FARequest struct {
TempToken string `json:"temp_token" binding:"required"`
TotpCode string `json:"totp_code" binding:"required,len=6"`
}
// Login2FA completes the login with 2FA verification
// POST /api/v1/auth/login/2fa
func (h *AuthHandler) Login2FA(c *gin.Context) {
var req Login2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
slog.Debug("login_2fa_request",
"temp_token_len", len(req.TempToken),
"totp_code_len", len(req.TotpCode))
// Get the login session
session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
if err != nil || session == nil {
tokenPrefix := ""
if len(req.TempToken) >= 8 {
tokenPrefix = req.TempToken[:8]
}
slog.Debug("login_2fa_session_invalid",
"temp_token_prefix", tokenPrefix,
"error", err)
response.BadRequest(c, "Invalid or expired 2FA session")
return
}
slog.Debug("login_2fa_session_found",
"user_id", session.UserID,
"email", session.Email)
// Verify the TOTP code
if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
slog.Debug("login_2fa_verify_failed",
"user_id", session.UserID,
"error", err)
response.ErrorFrom(c, err)
return
}
// Delete the login session
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
// Get the user
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Generate the JWT token
token, err := h.authService.GenerateToken(user)
if err != nil {
response.InternalError(c, "Failed to generate token")
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
@@ -195,6 +293,15 @@ type ValidatePromoCodeResponse struct {
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
// POST /api/v1/auth/validate-promo-code
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
// 检查优惠码功能是否启用
if h.settingSvc != nil && !h.settingSvc.IsPromoCodeEnabled(c.Request.Context()) {
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: "PROMO_CODE_DISABLED",
})
return
}
var req ValidatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
@@ -238,3 +345,85 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
BonusAmount: promoCode.BonusAmount,
})
}
// ForgotPasswordRequest 忘记密码请求
type ForgotPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
TurnstileToken string `json:"turnstile_token"`
}
// ForgotPasswordResponse 忘记密码响应
type ForgotPasswordResponse struct {
Message string `json:"message"`
}
// ForgotPassword 请求密码重置
// POST /api/v1/auth/forgot-password
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
var req ForgotPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
// Build frontend base URL from request
scheme := "https"
if c.Request.TLS == nil {
// Check X-Forwarded-Proto header (common in reverse proxy setups)
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
scheme = proto
} else {
scheme = "http"
}
}
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ForgotPasswordResponse{
Message: "If your email is registered, you will receive a password reset link shortly.",
})
}
// ResetPasswordRequest 重置密码请求
type ResetPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
Token string `json:"token" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
// ResetPasswordResponse 重置密码响应
type ResetPasswordResponse struct {
Message string `json:"message"`
}
// ResetPassword 重置密码
// POST /api/v1/auth/reset-password
func (h *AuthHandler) ResetPassword(c *gin.Context) {
var req ResetPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Reset password
if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ResetPasswordResponse{
Message: "Your password has been reset successfully. You can now log in with your new password.",
})
}

View File

@@ -0,0 +1,74 @@
package dto
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type Announcement struct {
ID int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
Status string `json:"status"`
Targeting service.AnnouncementTargeting `json:"targeting"`
StartsAt *time.Time `json:"starts_at,omitempty"`
EndsAt *time.Time `json:"ends_at,omitempty"`
CreatedBy *int64 `json:"created_by,omitempty"`
UpdatedBy *int64 `json:"updated_by,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type UserAnnouncement struct {
ID int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
StartsAt *time.Time `json:"starts_at,omitempty"`
EndsAt *time.Time `json:"ends_at,omitempty"`
ReadAt *time.Time `json:"read_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func AnnouncementFromService(a *service.Announcement) *Announcement {
if a == nil {
return nil
}
return &Announcement{
ID: a.ID,
Title: a.Title,
Content: a.Content,
Status: a.Status,
Targeting: a.Targeting,
StartsAt: a.StartsAt,
EndsAt: a.EndsAt,
CreatedBy: a.CreatedBy,
UpdatedBy: a.UpdatedBy,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
}
}
func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement {
if a == nil {
return nil
}
return &UserAnnouncement{
ID: a.Announcement.ID,
Title: a.Announcement.Title,
Content: a.Announcement.Content,
StartsAt: a.Announcement.StartsAt,
EndsAt: a.Announcement.EndsAt,
ReadAt: a.ReadAt,
CreatedAt: a.Announcement.CreatedAt,
UpdatedAt: a.Announcement.UpdatedAt,
}
}

View File

@@ -15,7 +15,6 @@ func UserFromServiceShallow(u *service.User) *User {
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
@@ -48,6 +47,22 @@ func UserFromService(u *service.User) *User {
return out
}
// UserFromServiceAdmin converts a service User to DTO for admin users.
// It includes notes - user-facing endpoints must not use this.
func UserFromServiceAdmin(u *service.User) *AdminUser {
if u == nil {
return nil
}
base := UserFromService(u)
if base == nil {
return nil
}
return &AdminUser{
User: *base,
Notes: u.Notes,
}
}
func APIKeyFromService(k *service.APIKey) *APIKey {
if k == nil {
return nil
@@ -72,36 +87,29 @@ func GroupFromServiceShallow(g *service.Group) *Group {
if g == nil {
return nil
}
return &Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
}
out := groupFromServiceBase(g)
return &out
}
func GroupFromService(g *service.Group) *Group {
if g == nil {
return nil
}
out := GroupFromServiceShallow(g)
return GroupFromServiceShallow(g)
}
// GroupFromServiceAdmin converts a service Group to DTO for admin users.
// It includes internal fields like model_routing and account_count.
func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
if g == nil {
return nil
}
out := &AdminGroup{
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
AccountCount: g.AccountCount,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
for i := range g.AccountGroups {
@@ -112,6 +120,29 @@ func GroupFromService(g *service.Group) *Group {
return out
}
func groupFromServiceBase(g *service.Group) Group {
return Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
}
func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
@@ -273,7 +304,24 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
if rc == nil {
return nil
}
return &RedeemCode{
out := redeemCodeFromServiceBase(rc)
return &out
}
// RedeemCodeFromServiceAdmin converts a service RedeemCode to DTO for admin users.
// It includes notes - user-facing endpoints must not use this.
func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
if rc == nil {
return nil
}
return &AdminRedeemCode{
RedeemCode: redeemCodeFromServiceBase(rc),
Notes: rc.Notes,
}
}
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
out := RedeemCode{
ID: rc.ID,
Code: rc.Code,
Type: rc.Type,
@@ -281,13 +329,20 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
Status: rc.Status,
UsedBy: rc.UsedBy,
UsedAt: rc.UsedAt,
Notes: rc.Notes,
CreatedAt: rc.CreatedAt,
GroupID: rc.GroupID,
ValidityDays: rc.ValidityDays,
User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group),
}
// For admin_balance/admin_concurrency types, include notes so users can see
// why they were charged or credited by admin
if (rc.Type == "admin_balance" || rc.Type == "admin_concurrency") && rc.Notes != "" {
out.Notes = &rc.Notes
}
return out
}
// AccountSummaryFromService returns a minimal AccountSummary for usage log display.
@@ -302,14 +357,9 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
}
}
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
// The account parameter allows caller to control what Account info is included.
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
if l == nil {
return nil
}
result := &UsageLog{
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
// 普通用户 DTO严禁包含管理员字段例如 account_rate_multiplier、ip_address、account
return UsageLog{
ID: l.ID,
UserID: l.UserID,
APIKeyID: l.APIKeyID,
@@ -331,7 +381,6 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
TotalCost: l.TotalCost,
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
AccountRateMultiplier: l.AccountRateMultiplier,
BillingType: l.BillingType,
Stream: l.Stream,
DurationMs: l.DurationMs,
@@ -342,30 +391,33 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey),
Account: account,
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),
}
// IP 地址仅对管理员可见
if includeIPAddress {
result.IPAddress = l.IPAddress
}
return result
}
// UsageLogFromService converts a service UsageLog to DTO for regular users.
// It excludes Account details and IP address - users should not see these.
func UsageLogFromService(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, nil, false)
if l == nil {
return nil
}
u := usageLogFromServiceUser(l)
return &u
}
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
// It includes minimal Account info (ID, Name only) and IP address.
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
if l == nil {
return nil
}
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
return &AdminUsageLog{
UsageLog: usageLogFromServiceUser(l),
AccountRateMultiplier: l.AccountRateMultiplier,
IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account),
}
}
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
@@ -414,7 +466,27 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
if sub == nil {
return nil
}
return &UserSubscription{
out := userSubscriptionFromServiceBase(sub)
return &out
}
// UserSubscriptionFromServiceAdmin converts a service UserSubscription to DTO for admin users.
// It includes assignment metadata and notes.
func UserSubscriptionFromServiceAdmin(sub *service.UserSubscription) *AdminUserSubscription {
if sub == nil {
return nil
}
return &AdminUserSubscription{
UserSubscription: userSubscriptionFromServiceBase(sub),
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}
func userSubscriptionFromServiceBase(sub *service.UserSubscription) UserSubscription {
return UserSubscription{
ID: sub.ID,
UserID: sub.UserID,
GroupID: sub.GroupID,
@@ -427,14 +499,10 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
DailyUsageUSD: sub.DailyUsageUSD,
WeeklyUsageUSD: sub.WeeklyUsageUSD,
MonthlyUsageUSD: sub.MonthlyUsageUSD,
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
CreatedAt: sub.CreatedAt,
UpdatedAt: sub.UpdatedAt,
User: UserFromServiceShallow(sub.User),
Group: GroupFromServiceShallow(sub.Group),
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}
@@ -442,9 +510,9 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
if r == nil {
return nil
}
subs := make([]UserSubscription, 0, len(r.Subscriptions))
subs := make([]AdminUserSubscription, 0, len(r.Subscriptions))
for i := range r.Subscriptions {
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
}
return &BulkAssignResult{
SuccessCount: r.SuccessCount,

View File

@@ -2,8 +2,12 @@ package dto
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@@ -22,13 +26,16 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
@@ -52,19 +59,25 @@ type SystemSettings struct {
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
}
// StreamTimeoutSettings 流超时处理配置 DTO

View File

@@ -6,7 +6,6 @@ type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Notes string `json:"notes"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
@@ -19,6 +18,14 @@ type User struct {
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
// AdminUser 是管理员接口使用的 user DTO包含敏感/内部字段)。
// 注意:普通用户接口不得返回 notes 等管理员备注信息。
type AdminUser struct {
User
Notes string `json:"notes"`
}
type APIKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
@@ -58,13 +65,19 @@ type Group struct {
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AdminGroup 是管理员接口使用的 group DTO包含敏感/内部字段)。
// 注意:普通用户接口不得返回 model_routing/account_count/account_groups 等内部信息。
type AdminGroup struct {
Group
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
}
@@ -180,16 +193,28 @@ type RedeemCode struct {
Status string `json:"status"`
UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
// Notes is only populated for admin_balance/admin_concurrency types
// so users can see why they were charged or credited
Notes *string `json:"notes,omitempty"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
// AdminRedeemCode 是管理员接口使用的 redeem code DTO包含 notes 等字段)。
// 注意:普通用户接口不得返回 notes 等内部信息。
type AdminRedeemCode struct {
RedeemCode
Notes string `json:"notes"`
}
// UsageLog 是普通用户接口使用的 usage log DTO不包含管理员字段
type UsageLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
@@ -209,14 +234,13 @@ type UsageLog struct {
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"`
@@ -230,18 +254,28 @@ type UsageLog struct {
// User-Agent
UserAgent *string `json:"user_agent"`
// IP 地址(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
APIKey *APIKey `json:"api_key,omitempty"`
Account *AccountSummary `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`
}
// AdminUsageLog 是管理员接口使用的 usage log DTO包含管理员字段
type AdminUsageLog struct {
UsageLog
// AccountRateMultiplier 账号计费倍率快照nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
// IPAddress 用户请求 IP仅管理员可见
IPAddress *string `json:"ip_address,omitempty"`
// Account 最小账号信息(避免泄露敏感字段)
Account *AccountSummary `json:"account,omitempty"`
}
type UsageCleanupFilters struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
@@ -300,23 +334,30 @@ type UserSubscription struct {
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
// AdminUserSubscription 是管理员接口使用的订阅 DTO包含分配信息/备注等字段)。
// 注意:普通用户接口不得返回 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段。
type AdminUserSubscription struct {
UserSubscription
AssignedBy *int64 `json:"assigned_by"`
AssignedAt time.Time `json:"assigned_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
}
type BulkAssignResult struct {
SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"`
Subscriptions []UserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"`
Subscriptions []AdminUserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
}
// PromoCode 注册优惠码

View File

@@ -30,6 +30,7 @@ type GatewayHandler struct {
antigravityGatewayService *service.AntigravityGatewayService
userService *service.UserService
billingCacheService *service.BillingCacheService
usageService *service.UsageService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
@@ -43,6 +44,7 @@ func NewGatewayHandler(
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
usageService *service.UsageService,
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
@@ -63,6 +65,7 @@ func NewGatewayHandler(
antigravityGatewayService: antigravityGatewayService,
userService: userService,
billingCacheService: billingCacheService,
usageService: usageService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
@@ -209,17 +212,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
// 检查请求拦截(预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
@@ -344,17 +350,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
// 检查请求拦截(预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
@@ -518,7 +527,7 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
})
}
// Usage handles getting account balance for CC Switch integration
// Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
@@ -533,7 +542,40 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return
}
// 订阅模式:返回订阅限额信息
// Best-effort: 获取用量统计,失败不影响基础响应
var usageData gin.H
if h.usageService != nil {
dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
if err == nil && dashStats != nil {
usageData = gin.H{
"today": gin.H{
"requests": dashStats.TodayRequests,
"input_tokens": dashStats.TodayInputTokens,
"output_tokens": dashStats.TodayOutputTokens,
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
"cache_read_tokens": dashStats.TodayCacheReadTokens,
"total_tokens": dashStats.TodayTokens,
"cost": dashStats.TodayCost,
"actual_cost": dashStats.TodayActualCost,
},
"total": gin.H{
"requests": dashStats.TotalRequests,
"input_tokens": dashStats.TotalInputTokens,
"output_tokens": dashStats.TotalOutputTokens,
"cache_creation_tokens": dashStats.TotalCacheCreationTokens,
"cache_read_tokens": dashStats.TotalCacheReadTokens,
"total_tokens": dashStats.TotalTokens,
"cost": dashStats.TotalCost,
"actual_cost": dashStats.TotalActualCost,
},
"average_duration_ms": dashStats.AverageDurationMs,
"rpm": dashStats.Rpm,
"tpm": dashStats.Tpm,
}
}
}
// 订阅模式:返回订阅限额信息 + 用量统计
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware2.GetSubscriptionFromContext(c)
if !ok {
@@ -542,28 +584,46 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
}
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
c.JSON(http.StatusOK, gin.H{
resp := gin.H{
"isValid": true,
"planName": apiKey.Group.Name,
"remaining": remaining,
"unit": "USD",
})
"subscription": gin.H{
"daily_usage_usd": subscription.DailyUsageUSD,
"weekly_usage_usd": subscription.WeeklyUsageUSD,
"monthly_usage_usd": subscription.MonthlyUsageUSD,
"daily_limit_usd": apiKey.Group.DailyLimitUSD,
"weekly_limit_usd": apiKey.Group.WeeklyLimitUSD,
"monthly_limit_usd": apiKey.Group.MonthlyLimitUSD,
"expires_at": subscription.ExpiresAt,
},
}
if usageData != nil {
resp["usage"] = usageData
}
c.JSON(http.StatusOK, resp)
return
}
// 余额模式:返回钱包余额
// 余额模式:返回钱包余额 + 用量统计
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return
}
c.JSON(http.StatusOK, gin.H{
resp := gin.H{
"isValid": true,
"planName": "钱包余额",
"remaining": latestUser.Balance,
"unit": "USD",
})
"balance": latestUser.Balance,
}
if usageData != nil {
resp["usage"] = usageData
}
c.JSON(http.StatusOK, resp)
}
// calculateSubscriptionRemaining 计算订阅剩余可用额度
@@ -765,17 +825,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
}
// isWarmupRequest 检测是否为预热请求标题生成、Warmup等
func isWarmupRequest(body []byte) bool {
// 快速检查如果body不包含关键字直接返回false
// InterceptType 表示请求拦截类型
type InterceptType int
const (
InterceptTypeNone InterceptType = iota
InterceptTypeWarmup // 预热请求(返回 "New Conversation"
InterceptTypeSuggestionMode // SUGGESTION MODE返回空字符串
)
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
func detectInterceptType(body []byte) InterceptType {
// 快速检查:如果不包含任何关键字,直接返回
bodyStr := string(body)
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
return false
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
if !hasSuggestionMode && !hasWarmupKeyword {
return InterceptTypeNone
}
// 解析完整请求
// 解析请求(只解析一次)
var req struct {
Messages []struct {
Role string `json:"role"`
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
@@ -786,43 +859,71 @@ func isWarmupRequest(body []byte) bool {
} `json:"system"`
}
if err := json.Unmarshal(body, &req); err != nil {
return false
return InterceptTypeNone
}
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return true
// 检查 SUGGESTION MODE最后一条 user 消息)
if hasSuggestionMode && len(req.Messages) > 0 {
lastMsg := req.Messages[len(req.Messages)-1]
if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
lastMsg.Content[0].Type == "text" &&
strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
return InterceptTypeSuggestionMode
}
}
// 检查 Warmup 请求
if hasWarmupKeyword {
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return InterceptTypeWarmup
}
}
}
}
// 检查 system 中的标题提取模式
for _, sys := range req.System {
if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return InterceptTypeWarmup
}
}
}
// 检查 system 中的标题提取模式
for _, system := range req.System {
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return true
}
}
return false
return InterceptTypeNone
}
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
func sendMockWarmupStream(c *gin.Context, model string) {
// sendMockInterceptStream 发送流式 mock 响应(用于请求拦截)
func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
// 根据拦截类型决定响应内容
var msgID string
var outputTokens int
var textDeltas []string
switch interceptType {
case InterceptTypeSuggestionMode:
msgID = "msg_mock_suggestion"
outputTokens = 1
textDeltas = []string{""} // 空内容
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
outputTokens = 2
textDeltas = []string{"New", " Conversation"}
}
// Build message_start event with proper JSON marshaling
messageStart := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": "msg_mock_warmup",
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
@@ -837,16 +938,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
}
messageStartJSON, _ := json.Marshal(messageStart)
// Build events
events := []string{
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
}
// Add text deltas
for _, text := range textDeltas {
delta := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]string{
"type": "text_delta",
"text": text,
},
}
deltaJSON, _ := json.Marshal(delta)
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
}
// Add final events
messageDelta := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": "end_turn",
"stop_sequence": nil,
},
"usage": map[string]int{
"input_tokens": 10,
"output_tokens": outputTokens,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
events = append(events,
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
`event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
)
for _, event := range events {
_, _ = c.Writer.WriteString(event + "\n\n")
c.Writer.Flush()
@@ -854,18 +985,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
}
}
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) {
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
var msgID, text string
var outputTokens int
switch interceptType {
case InterceptTypeSuggestionMode:
msgID = "msg_mock_suggestion"
text = ""
outputTokens = 1
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
text = "New Conversation"
outputTokens = 2
}
c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup",
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
"content": []gin.H{{"type": "text", "text": text}},
"stop_reason": "end_turn",
"usage": gin.H{
"input_tokens": 10,
"output_tokens": 2,
"output_tokens": outputTokens,
},
})
}

View File

@@ -0,0 +1,122 @@
//go:build unit
package handler
import (
"crypto/sha256"
"encoding/hex"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractGeminiCLISessionHash(t *testing.T) {
tests := []struct {
name string
body string
privilegedUserID string
wantEmpty bool
wantHash string
}{
{
name: "with privileged-user-id and tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: false,
wantHash: func() string {
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}(),
},
{
name: "without privileged-user-id but with tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "",
wantEmpty: false,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "without tmp dir",
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
{
name: "empty body",
body: "",
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 创建测试上下文
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/test", nil)
if tt.privilegedUserID != "" {
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
}
// 调用函数
result := extractGeminiCLISessionHash(c, []byte(tt.body))
// 验证结果
if tt.wantEmpty {
require.Empty(t, result, "expected empty session hash")
} else {
require.NotEmpty(t, result, "expected non-empty session hash")
require.Equal(t, tt.wantHash, result, "session hash mismatch")
}
})
}
}
func TestGeminiCLITmpDirRegex(t *testing.T) {
tests := []struct {
name string
input string
wantMatch bool
wantHash string
}{
{
name: "valid tmp dir path",
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "valid tmp dir path in text",
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "invalid hash length",
input: "/Users/ianshaw/.gemini/tmp/abc123",
wantMatch: false,
},
{
name: "no tmp dir",
input: "Hello world",
wantMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
if tt.wantMatch {
require.NotNil(t, match, "expected regex to match")
require.Len(t, match, 2, "expected 2 capture groups")
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
} else {
require.Nil(t, match, "expected regex not to match")
}
})
}
}

View File

@@ -1,11 +1,15 @@
package handler
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"io"
"log"
"net/http"
"regexp"
"strings"
"time"
@@ -19,6 +23,17 @@ import (
"github.com/gin-gonic/gin"
)
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
return true
}
return geminiCLITmpDirRegex.Match(body)
}
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 优先使用 Gemini CLI 的会话标识privileged-user-id + tmp 目录哈希)
sessionHash := extractGeminiCLISessionHash(c, body)
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
// 查询粘性会话绑定的账号 ID用于检测账号切换
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后CLI 继续携带旧签名。
// 为避免第一次转发就 400这里做一次确定性清理让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
} else if sessionBoundAccountID == 0 {
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
sessionBoundAccountID = account.ID
}
// 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
}
return false
}
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
//
// 会话标识生成策略:
// 1. 从请求体中提取 tmp 目录哈希64位十六进制
// 2. 从 header 中提取 privileged-user-idUUID
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
//
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
//
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
// 1. 从请求体中提取 tmp 目录哈希
match := geminiCLITmpDirRegex.FindSubmatch(body)
if len(match) < 2 {
return "" // 没有找到 tmp 目录,不使用粘性会话
}
tmpDirHash := string(match[1])
// 2. 提取 privileged-user-id
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
// 3. 组合生成最终的 session hash
if privilegedUserID != "" {
// 组合两个标识符privileged-user-id + tmp 目录哈希
combined := privilegedUserID + ":" + tmpDirHash
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}
// 如果没有 privileged-user-id直接使用 tmp 目录哈希
return tmpDirHash
}

View File

@@ -10,6 +10,7 @@ type AdminHandlers struct {
User *admin.UserHandler
Group *admin.GroupHandler
Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
@@ -33,10 +34,12 @@ type Handlers struct {
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
Announcement *AnnouncementHandler
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler
Totp *TotpHandler
}
// BuildInfo contains build-time information

View File

@@ -905,7 +905,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
return true
}
if phase == "billing" || phase == "concurrency" {
@@ -1011,5 +1011,12 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
}
}
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
if settings.IgnoreInvalidApiKeyErrors {
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
return true
}
}
return false
}

View File

@@ -32,18 +32,24 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
}
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
})
}

View File

@@ -0,0 +1,181 @@
package handler
import (
"github.com/gin-gonic/gin"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TotpHandler handles TOTP-related requests
type TotpHandler struct {
totpService *service.TotpService
}
// NewTotpHandler creates a new TotpHandler
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
return &TotpHandler{
totpService: totpService,
}
}
// TotpStatusResponse represents the TOTP status response
type TotpStatusResponse struct {
Enabled bool `json:"enabled"`
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
FeatureEnabled bool `json:"feature_enabled"`
}
// GetStatus returns the TOTP status for the current user
// GET /api/v1/user/totp/status
func (h *TotpHandler) GetStatus(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := TotpStatusResponse{
Enabled: status.Enabled,
FeatureEnabled: status.FeatureEnabled,
}
if status.EnabledAt != nil {
ts := status.EnabledAt.Unix()
resp.EnabledAt = &ts
}
response.Success(c, resp)
}
// TotpSetupRequest represents the request to initiate TOTP setup
type TotpSetupRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// TotpSetupResponse represents the TOTP setup response
type TotpSetupResponse struct {
Secret string `json:"secret"`
QRCodeURL string `json:"qr_code_url"`
SetupToken string `json:"setup_token"`
Countdown int `json:"countdown"`
}
// InitiateSetup starts the TOTP setup process
// POST /api/v1/user/totp/setup
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpSetupRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body (optional params)
req = TotpSetupRequest{}
}
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, TotpSetupResponse{
Secret: result.Secret,
QRCodeURL: result.QRCodeURL,
SetupToken: result.SetupToken,
Countdown: result.Countdown,
})
}
// TotpEnableRequest represents the request to enable TOTP
type TotpEnableRequest struct {
TotpCode string `json:"totp_code" binding:"required,len=6"`
SetupToken string `json:"setup_token" binding:"required"`
}
// Enable completes the TOTP setup
// POST /api/v1/user/totp/enable
func (h *TotpHandler) Enable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpEnableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// TotpDisableRequest represents the request to disable TOTP
type TotpDisableRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// Disable disables TOTP for the current user
// POST /api/v1/user/totp/disable
func (h *TotpHandler) Disable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpDisableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// GetVerificationMethod returns the verification method for TOTP operations
// GET /api/v1/user/totp/verification-method
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
method := h.totpService.GetVerificationMethod(c.Request.Context())
response.Success(c, method)
}
// SendVerifyCode sends an email verification code for TOTP operations
// POST /api/v1/user/totp/send-code
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}

View File

@@ -47,9 +47,6 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return
}
// 清空notes字段普通用户不应看到备注
userData.Notes = ""
response.Success(c, dto.UserFromService(userData))
}
@@ -105,8 +102,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return
}
// 清空notes字段普通用户不应看到备注
updatedUser.Notes = ""
response.Success(c, dto.UserFromService(updatedUser))
}

View File

@@ -13,6 +13,7 @@ func ProvideAdminHandlers(
userHandler *admin.UserHandler,
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
announcementHandler *admin.AnnouncementHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
@@ -32,6 +33,7 @@ func ProvideAdminHandlers(
User: userHandler,
Group: groupHandler,
Account: accountHandler,
Announcement: announcementHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
@@ -66,10 +68,12 @@ func ProvideHandlers(
usageHandler *UsageHandler,
redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler,
announcementHandler *AnnouncementHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
settingHandler *SettingHandler,
totpHandler *TotpHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
@@ -78,10 +82,12 @@ func ProvideHandlers(
Usage: usageHandler,
Redeem: redeemHandler,
Subscription: subscriptionHandler,
Announcement: announcementHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler,
Totp: totpHandler,
}
}
@@ -94,8 +100,10 @@ var ProviderSet = wire.NewSet(
NewUsageHandler,
NewRedeemHandler,
NewSubscriptionHandler,
NewAnnouncementHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
NewTotpHandler,
ProvideSettingHandler,
// Admin handlers
@@ -103,6 +111,7 @@ var ProviderSet = wire.NewSet(
admin.NewUserHandler,
admin.NewGroupHandler,
admin.NewAccountHandler,
admin.NewAnnouncementHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,

View File

@@ -33,7 +33,7 @@ const (
"https://www.googleapis.com/auth/experimentsandconfigs"
// User-Agent与 Antigravity-Manager 保持一致)
UserAgent = "antigravity/1.11.9 windows/amd64"
UserAgent = "antigravity/1.15.8 windows/amd64"
// Session 过期时间
SessionTTL = 30 * time.Minute

View File

@@ -7,13 +7,11 @@ import (
"fmt"
"log"
"math/rand"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -369,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
Text: block.Thinking,
Thought: true,
}
// 保留原有 signatureClaude 模型需要有效的 signature
if block.Signature != "" {
// signature 处理:
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失)
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature在缺失时降级为普通文本并在上层禁用 thinking mode。
@@ -409,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
},
}
// tool_use 的 signature 处理:
// - Gemini 模型:使用 dummy signature跳过 thought_signature 校验
// - Claude 模型:透传上游返回的真实 signatureVertex/Google 需要完整签名链路)
if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
@@ -594,11 +594,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
// 清理 JSON Schema
params := cleanJSONSchema(inputSchema)
// 1. 深度清理 [undefined] 值
DeepCleanUndefined(inputSchema)
// 2. 转换为符合 Gemini v1internal 的 schema
params := CleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值
if params == nil {
params = map[string]any{
"type": "OBJECT",
"type": "object", // lowercase type
"properties": map[string]any{},
}
}
@@ -631,236 +634,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
FunctionDeclarations: funcDecls,
}}
}
// cleanJSONSchema 清理 JSON Schema移除 Antigravity/Gemini 不支持的字段
// 参考 proxycast 的实现,确保 schema 符合 JSON Schema draft 2020-12
func cleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
cleaned := cleanSchemaValue(schema, "$")
result, ok := cleaned.(map[string]any)
if !ok {
return nil
}
// 确保有 type 字段(默认 OBJECT
if _, hasType := result["type"]; !hasType {
result["type"] = "OBJECT"
}
// 确保有 properties 字段(默认空对象)
if _, hasProps := result["properties"]; !hasProps {
result["properties"] = make(map[string]any)
}
// 验证 required 中的字段都存在于 properties 中
if required, ok := result["required"].([]any); ok {
if props, ok := result["properties"].(map[string]any); ok {
validRequired := make([]any, 0, len(required))
for _, r := range required {
if reqName, ok := r.(string); ok {
if _, exists := props[reqName]; exists {
validRequired = append(validRequired, r)
}
}
}
if len(validRequired) > 0 {
result["required"] = validRequired
} else {
delete(result, "required")
}
}
}
return result
}
var schemaValidationKeys = map[string]bool{
"minLength": true,
"maxLength": true,
"pattern": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
"uniqueItems": true,
"minItems": true,
"maxItems": true,
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
}
var warnedSchemaKeys sync.Map
func schemaCleaningWarningsEnabled() bool {
// 可通过环境变量强制开关方便排查SUB2API_SCHEMA_CLEAN_WARN=true/false
if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
}
}
// 默认:非 release 模式下输出debug/test
return gin.Mode() != gin.ReleaseMode
}
func warnSchemaKeyRemovedOnce(key, path string) {
if !schemaCleaningWarningsEnabled() {
return
}
if !schemaValidationKeys[key] {
return
}
if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
return
}
log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{
// 元 schema 字段
"$schema": true,
"$id": true,
"$ref": true,
// 字符串验证Gemini 不支持)
"minLength": true,
"maxLength": true,
"pattern": true,
// 数字验证Claude API 通过 Vertex AI 不支持这些字段)
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
// 数组验证Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems": true,
"minItems": true,
"maxItems": true,
// 组合 schemaGemini 不支持)
"oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// Claude 特有字段
"strict": true,
}
// cleanSchemaValue 递归清理 schema 值
func cleanSchemaValue(value any, path string) any {
switch v := value.(type) {
case map[string]any:
result := make(map[string]any)
for k, val := range v {
// 跳过不支持的字段
if excludedSchemaKeys[k] {
warnSchemaKeyRemovedOnce(k, path)
continue
}
// 特殊处理 type 字段
if k == "type" {
result[k] = cleanTypeValue(val)
continue
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalPropertiesClaude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
} else {
// 如果是 schema 对象,转换为 false更安全的默认值
result[k] = false
}
continue
}
// 递归清理所有值
result[k] = cleanSchemaValue(val, path+"."+k)
}
return result
case []any:
// 递归处理数组中的每个元素
cleaned := make([]any, 0, len(v))
for i, item := range v {
cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
}
return cleaned
default:
return value
}
}
// cleanTypeValue 处理 type 字段,转换为大写
func cleanTypeValue(value any) any {
switch v := value.(type) {
case string:
return strings.ToUpper(v)
case []any:
// 联合类型 ["string", "null"] -> 取第一个非 null 类型
for _, t := range v {
if ts, ok := t.(string); ok && ts != "null" {
return strings.ToUpper(ts)
}
}
// 如果只有 null返回 STRING
return "STRING"
default:
return value
}
}

View File

@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]`
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil {
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != "sig_tool_abc" {
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
}
})
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
contentNoSig := `[
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
]`
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
}

View File

@@ -3,6 +3,7 @@ package antigravity
import (
"encoding/json"
"fmt"
"log"
"strings"
)
@@ -19,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
} else if len(v1Resp.Response.Candidates) == 0 {
// 第一次解析成功但 candidates 为空,说明是直接的 GeminiResponse 格式
var directResp GeminiResponse
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
return nil, nil, fmt.Errorf("parse gemini response as direct: %w", err2)
}
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
}
// 使用处理器转换
@@ -173,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
p.trailingSignature = ""
}
p.textBuilder += part.Text
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
// 非空 text 带签名 - 特殊处理:先输出 text再输出空 thinking 块
if signature != "" {
p.flushText()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "text",
Text: part.Text,
})
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: signature,
})
} else {
// 普通 text (无签名) - 累积到 builder
p.textBuilder += part.Text
}
}
}
@@ -242,6 +256,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
var finishReason string
if len(geminiResp.Candidates) > 0 {
finishReason = geminiResp.Candidates[0].FinishReason
if finishReason == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in response for model %s", originalModel)
if geminiResp.Candidates[0].Content != nil {
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
stopReason := "end_turn"

View File

@@ -0,0 +1,519 @@
package antigravity
import (
"fmt"
"strings"
)
// CleanJSONSchema 清理 JSON Schema移除 Antigravity/Gemini 不支持的字段
// 参考 Antigravity-Manager/src-tauri/src/proxy/common/json_schema.rs 实现
// 确保 schema 符合 JSON Schema draft 2020-12 且适配 Gemini v1internal
func CleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
// 0. 预处理:展开 $ref (Schema Flattening)
// (Go map 是引用的,直接修改 schema)
flattenRefs(schema, extractDefs(schema))
// 递归清理
cleaned := cleanJSONSchemaRecursive(schema)
result, ok := cleaned.(map[string]any)
if !ok {
return nil
}
return result
}
// extractDefs 提取并移除定义的 helper
func extractDefs(schema map[string]any) map[string]any {
defs := make(map[string]any)
if d, ok := schema["$defs"].(map[string]any); ok {
for k, v := range d {
defs[k] = v
}
delete(schema, "$defs")
}
if d, ok := schema["definitions"].(map[string]any); ok {
for k, v := range d {
defs[k] = v
}
delete(schema, "definitions")
}
return defs
}
// flattenRefs 递归展开 $ref
func flattenRefs(schema map[string]any, defs map[string]any) {
if len(defs) == 0 {
return // 无需展开
}
// 检查并替换 $ref
if ref, ok := schema["$ref"].(string); ok {
delete(schema, "$ref")
// 解析引用名 (例如 #/$defs/MyType -> MyType)
parts := strings.Split(ref, "/")
refName := parts[len(parts)-1]
if defSchema, exists := defs[refName]; exists {
if defMap, ok := defSchema.(map[string]any); ok {
// 合并定义内容 (不覆盖现有 key)
for k, v := range defMap {
if _, has := schema[k]; !has {
schema[k] = deepCopy(v) // 需深拷贝避免共享引用
}
}
// 递归处理刚刚合并进来的内容
flattenRefs(schema, defs)
}
}
}
// 遍历子节点
for _, v := range schema {
if subMap, ok := v.(map[string]any); ok {
flattenRefs(subMap, defs)
} else if subArr, ok := v.([]any); ok {
for _, item := range subArr {
if itemMap, ok := item.(map[string]any); ok {
flattenRefs(itemMap, defs)
}
}
}
}
}
// deepCopy 深拷贝 (简单实现,仅针对 JSON 类型)
func deepCopy(src any) any {
if src == nil {
return nil
}
switch v := src.(type) {
case map[string]any:
dst := make(map[string]any)
for k, val := range v {
dst[k] = deepCopy(val)
}
return dst
case []any:
dst := make([]any, len(v))
for i, val := range v {
dst[i] = deepCopy(val)
}
return dst
default:
return src
}
}
// cleanJSONSchemaRecursive 递归核心清理逻辑
// 返回处理后的值 (通常是 input map但可能修改内部结构)
func cleanJSONSchemaRecursive(value any) any {
schemaMap, ok := value.(map[string]any)
if !ok {
return value
}
// 0. [NEW] 合并 allOf
mergeAllOf(schemaMap)
// 1. [CRITICAL] 深度递归处理子项
if props, ok := schemaMap["properties"].(map[string]any); ok {
for _, v := range props {
cleanJSONSchemaRecursive(v)
}
// Go 中不需要像 Rust 那样显式处理 nullable_keys remove required
// 因为我们在子项处理中会正确设置 type 和 description
} else if items, ok := schemaMap["items"]; ok {
// [FIX] Gemini 期望 "items" 是单个 Schema 对象(列表验证),而不是数组(元组验证)。
if itemsArr, ok := items.([]any); ok {
// 策略:将元组 [A, B] 视为 A、B 中的最佳匹配项。
best := extractBestSchemaFromUnion(itemsArr)
if best == nil {
// 回退到通用字符串
best = map[string]any{"type": "string"}
}
// 用处理后的对象替换原有数组
cleanedBest := cleanJSONSchemaRecursive(best)
schemaMap["items"] = cleanedBest
} else {
cleanJSONSchemaRecursive(items)
}
} else {
// 遍历所有值递归
for _, v := range schemaMap {
if _, isMap := v.(map[string]any); isMap {
cleanJSONSchemaRecursive(v)
} else if arr, isArr := v.([]any); isArr {
for _, item := range arr {
cleanJSONSchemaRecursive(item)
}
}
}
}
// 2. [FIX] 处理 anyOf/oneOf 联合类型: 合并属性而非直接删除
var unionArray []any
typeStr, _ := schemaMap["type"].(string)
if typeStr == "" || typeStr == "object" {
if anyOf, ok := schemaMap["anyOf"].([]any); ok {
unionArray = anyOf
} else if oneOf, ok := schemaMap["oneOf"].([]any); ok {
unionArray = oneOf
}
}
if len(unionArray) > 0 {
if bestBranch := extractBestSchemaFromUnion(unionArray); bestBranch != nil {
if bestMap, ok := bestBranch.(map[string]any); ok {
// 合并分支内容
for k, v := range bestMap {
if k == "properties" {
targetProps, _ := schemaMap["properties"].(map[string]any)
if targetProps == nil {
targetProps = make(map[string]any)
schemaMap["properties"] = targetProps
}
if sourceProps, ok := v.(map[string]any); ok {
for pk, pv := range sourceProps {
if _, exists := targetProps[pk]; !exists {
targetProps[pk] = deepCopy(pv)
}
}
}
} else if k == "required" {
targetReq, _ := schemaMap["required"].([]any)
if sourceReq, ok := v.([]any); ok {
for _, rv := range sourceReq {
// 简单的去重添加
exists := false
for _, tr := range targetReq {
if tr == rv {
exists = true
break
}
}
if !exists {
targetReq = append(targetReq, rv)
}
}
schemaMap["required"] = targetReq
}
} else if _, exists := schemaMap[k]; !exists {
schemaMap[k] = deepCopy(v)
}
}
}
}
}
// 3. [SAFETY] 检查当前对象是否为 JSON Schema 节点
looksLikeSchema := hasKey(schemaMap, "type") ||
hasKey(schemaMap, "properties") ||
hasKey(schemaMap, "items") ||
hasKey(schemaMap, "enum") ||
hasKey(schemaMap, "anyOf") ||
hasKey(schemaMap, "oneOf") ||
hasKey(schemaMap, "allOf")
if looksLikeSchema {
// 4. [ROBUST] 约束迁移
migrateConstraints(schemaMap)
// 5. [CRITICAL] 白名单过滤
allowedFields := map[string]bool{
"type": true,
"description": true,
"properties": true,
"required": true,
"items": true,
"enum": true,
"title": true,
}
for k := range schemaMap {
if !allowedFields[k] {
delete(schemaMap, k)
}
}
// 6. [SAFETY] 处理空 Object
if t, _ := schemaMap["type"].(string); t == "object" {
hasProps := false
if props, ok := schemaMap["properties"].(map[string]any); ok && len(props) > 0 {
hasProps = true
}
if !hasProps {
schemaMap["properties"] = map[string]any{
"reason": map[string]any{
"type": "string",
"description": "Reason for calling this tool",
},
}
schemaMap["required"] = []any{"reason"}
}
}
// 7. [SAFETY] Required 字段对齐
if props, ok := schemaMap["properties"].(map[string]any); ok {
if req, ok := schemaMap["required"].([]any); ok {
var validReq []any
for _, r := range req {
if rStr, ok := r.(string); ok {
if _, exists := props[rStr]; exists {
validReq = append(validReq, r)
}
}
}
if len(validReq) > 0 {
schemaMap["required"] = validReq
} else {
delete(schemaMap, "required")
}
}
}
// 8. 处理 type 字段 (Lowercase + Nullable 提取)
isEffectivelyNullable := false
if typeVal, exists := schemaMap["type"]; exists {
var selectedType string
switch v := typeVal.(type) {
case string:
lower := strings.ToLower(v)
if lower == "null" {
isEffectivelyNullable = true
selectedType = "string" // fallback
} else {
selectedType = lower
}
case []any:
// ["string", "null"]
for _, t := range v {
if ts, ok := t.(string); ok {
lower := strings.ToLower(ts)
if lower == "null" {
isEffectivelyNullable = true
} else if selectedType == "" {
selectedType = lower
}
}
}
if selectedType == "" {
selectedType = "string"
}
}
schemaMap["type"] = selectedType
} else {
// 默认 object 如果有 properties (虽然上面白名单过滤可能删了 type 如果它不在... 但 type 必在 allowlist)
// 如果没有 type但有 properties补一个
if hasKey(schemaMap, "properties") {
schemaMap["type"] = "object"
} else {
// 默认为 string ? or object? Gemini 通常需要明确 type
schemaMap["type"] = "object"
}
}
if isEffectivelyNullable {
desc, _ := schemaMap["description"].(string)
if !strings.Contains(desc, "nullable") {
if desc != "" {
desc += " "
}
desc += "(nullable)"
schemaMap["description"] = desc
}
}
// 9. Enum 值强制转字符串
if enumVals, ok := schemaMap["enum"].([]any); ok {
hasNonString := false
for i, val := range enumVals {
if _, isStr := val.(string); !isStr {
hasNonString = true
if val == nil {
enumVals[i] = "null"
} else {
enumVals[i] = fmt.Sprintf("%v", val)
}
}
}
// If we mandated string values, we must ensure type is string
if hasNonString {
schemaMap["type"] = "string"
}
}
}
return schemaMap
}
func hasKey(m map[string]any, k string) bool {
_, ok := m[k]
return ok
}
func migrateConstraints(m map[string]any) {
constraints := []struct {
key string
label string
}{
{"minLength", "minLen"},
{"maxLength", "maxLen"},
{"pattern", "pattern"},
{"minimum", "min"},
{"maximum", "max"},
{"multipleOf", "multipleOf"},
{"exclusiveMinimum", "exclMin"},
{"exclusiveMaximum", "exclMax"},
{"minItems", "minItems"},
{"maxItems", "maxItems"},
{"propertyNames", "propertyNames"},
{"format", "format"},
}
var hints []string
for _, c := range constraints {
if val, ok := m[c.key]; ok && val != nil {
hints = append(hints, fmt.Sprintf("%s: %v", c.label, val))
}
}
if len(hints) > 0 {
suffix := fmt.Sprintf(" [Constraint: %s]", strings.Join(hints, ", "))
desc, _ := m["description"].(string)
if !strings.Contains(desc, suffix) {
m["description"] = desc + suffix
}
}
}
// mergeAllOf 合并 allOf
func mergeAllOf(m map[string]any) {
allOf, ok := m["allOf"].([]any)
if !ok {
return
}
delete(m, "allOf")
mergedProps := make(map[string]any)
mergedReq := make(map[string]bool)
otherFields := make(map[string]any)
for _, sub := range allOf {
if subMap, ok := sub.(map[string]any); ok {
// Props
if props, ok := subMap["properties"].(map[string]any); ok {
for k, v := range props {
mergedProps[k] = v
}
}
// Required
if reqs, ok := subMap["required"].([]any); ok {
for _, r := range reqs {
if s, ok := r.(string); ok {
mergedReq[s] = true
}
}
}
// Others
for k, v := range subMap {
if k != "properties" && k != "required" && k != "allOf" {
if _, exists := otherFields[k]; !exists {
otherFields[k] = v
}
}
}
}
}
// Apply
for k, v := range otherFields {
if _, exists := m[k]; !exists {
m[k] = v
}
}
if len(mergedProps) > 0 {
existProps, _ := m["properties"].(map[string]any)
if existProps == nil {
existProps = make(map[string]any)
m["properties"] = existProps
}
for k, v := range mergedProps {
if _, exists := existProps[k]; !exists {
existProps[k] = v
}
}
}
if len(mergedReq) > 0 {
existReq, _ := m["required"].([]any)
var validReqs []any
for _, r := range existReq {
if s, ok := r.(string); ok {
validReqs = append(validReqs, s)
delete(mergedReq, s) // already exists
}
}
// append new
for r := range mergedReq {
validReqs = append(validReqs, r)
}
m["required"] = validReqs
}
}
// extractBestSchemaFromUnion 从 anyOf/oneOf 中选取最佳分支
func extractBestSchemaFromUnion(unionArray []any) any {
var bestOption any
bestScore := -1
for _, item := range unionArray {
score := scoreSchemaOption(item)
if score > bestScore {
bestScore = score
bestOption = item
}
}
return bestOption
}
func scoreSchemaOption(val any) int {
m, ok := val.(map[string]any)
if !ok {
return 0
}
typeStr, _ := m["type"].(string)
if hasKey(m, "properties") || typeStr == "object" {
return 3
}
if hasKey(m, "items") || typeStr == "array" {
return 2
}
if typeStr != "" && typeStr != "null" {
return 1
}
return 0
}
// DeepCleanUndefined 深度清理值为 "[undefined]" 的字段
func DeepCleanUndefined(value any) {
if value == nil {
return
}
switch v := value.(type) {
case map[string]any:
for k, val := range v {
if s, ok := val.(string); ok && s == "[undefined]" {
delete(v, k)
continue
}
DeepCleanUndefined(val)
}
case []any:
for _, val := range v {
DeepCleanUndefined(val)
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"log"
"strings"
)
@@ -102,6 +103,14 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
// 检查是否结束
if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason
if finishReason == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in stream for model %s", p.originalModel)
if geminiResp.Candidates[0].Content != nil {
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
if finishReason != "" {
_, _ = result.Write(p.emitFinish(finishReason))
}

View File

@@ -13,20 +13,26 @@ import (
"time"
)
// Claude OAuth Constants (from CRS project)
// Claude OAuth Constants
const (
// OAuth Client ID for Claude
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
// OAuth endpoints
AuthorizeURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
TokenURL = "https://platform.claude.com/v1/oauth/token"
RedirectURI = "https://platform.claude.com/oauth/code/callback"
// Scopes
ScopeProfile = "user:profile"
// Scopes - Browser URL (includes org:create_api_key for user authorization)
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Internal API call (org:create_api_key not supported in API)
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Setup token (inference only)
ScopeInference = "user:inference"
// Code Verifier character set (RFC 7636 compliant)
codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
// Session TTL
SessionTTL = 30 * time.Minute
)
@@ -53,7 +59,6 @@ func NewSessionStore() *SessionStore {
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
@@ -78,7 +83,6 @@ func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
@@ -122,13 +126,13 @@ func GenerateRandomBytes(n int) ([]byte, error) {
return b, nil
}
// GenerateState generates a random state string for OAuth
// GenerateState generates a random state string for OAuth (base64url encoded)
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
return base64URLEncode(bytes), nil
}
// GenerateSessionID generates a unique session ID
@@ -140,13 +144,30 @@ func GenerateSessionID() (string, error) {
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
// GenerateCodeVerifier generates a PKCE code verifier using character set method
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
const targetLen = 32
charsetLen := len(codeVerifierCharset)
limit := 256 - (256 % charsetLen)
result := make([]byte, 0, targetLen)
randBuf := make([]byte, targetLen*2)
for len(result) < targetLen {
if _, err := rand.Read(randBuf); err != nil {
return "", err
}
for _, b := range randBuf {
if int(b) < limit {
result = append(result, codeVerifierCharset[int(b)%charsetLen])
if len(result) >= targetLen {
break
}
}
}
}
return base64URLEncode(bytes), nil
return base64URLEncode(result), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
@@ -158,42 +179,31 @@ func GenerateCodeChallenge(verifier string) string {
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OAuth authorization URL
// BuildAuthorizationURL builds the OAuth authorization URL with correct parameter order
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("scope", scope)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
encodedRedirectURI := url.QueryEscape(RedirectURI)
encodedScope := strings.ReplaceAll(url.QueryEscape(scope), "%20", "+")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
return fmt.Sprintf("%s?code=true&client_id=%s&response_type=code&redirect_uri=%s&scope=%s&code_challenge=%s&code_challenge_method=S256&state=%s",
AuthorizeURL,
ClientID,
encodedRedirectURI,
encodedScope,
codeChallenge,
state,
)
}
// TokenResponse represents the token response from OAuth provider
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
// Organization and Account info from OAuth response
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
Organization *OrgInfo `json:"organization,omitempty"`
Account *AccountInfo `json:"account,omitempty"`
}
@@ -205,33 +215,6 @@ type OrgInfo struct {
// AccountInfo represents account info from OAuth response
type AccountInfo struct {
UUID string `json:"uuid"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
}
// BuildTokenRequest creates a token exchange request
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: RedirectURI,
CodeVerifier: codeVerifier,
State: state,
}
}
// BuildRefreshTokenRequest creates a refresh token request
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
}
UUID string `json:"uuid"`
EmailAddress string `json:"email_address"`
}

View File

@@ -2,6 +2,7 @@
package response
import (
"log"
"math"
"net/http"
@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool {
}
statusCode, status := infraerrors.ToHTTP(err)
// Log internal errors with full details for debugging
if statusCode >= 500 && c.Request != nil {
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
}
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}

View File

@@ -0,0 +1,278 @@
//go:build integration
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests to external services and should be run manually.
//
// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
)
// skipIfExternalServiceUnavailable checks if the external service is available.
// If not, it skips the test instead of failing.
func skipIfExternalServiceUnavailable(t *testing.T, err error) {
t.Helper()
if err != nil {
// Check for common network/TLS errors that indicate external service issues
errStr := err.Error()
if strings.Contains(errStr, "certificate has expired") ||
strings.Contains(errStr, "certificate is not yet valid") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "no such host") ||
strings.Contains(errStr, "network is unreachable") ||
strings.Contains(errStr, "timeout") {
t.Skipf("skipping test: external service unavailable: %v", err)
}
t.Fatalf("failed to get fingerprint: %v", err)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
skipIfExternalServiceUnavailable(t, err)
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
skipIfExternalServiceUnavailable(t, err)
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}

View File

@@ -1,21 +1,16 @@
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests and should be run manually.
// Unit tests for TLS fingerprint dialer.
// Integration tests that require external network are in dialer_integration_test.go
// and require the 'integration' build tag.
//
// Run with: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
)
// FingerprintResponse represents the response from tls.peet.ws/api/all.
@@ -36,148 +31,6 @@ type TLSInfo struct {
SessionID string `json:"session_id"`
}
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
func TestDialerBasicConnection(t *testing.T) {
if testing.Short() {
t.Skip("skipping network test in short mode")
}
// Create a dialer with default profile
profile := &Profile{
Name: "Test Profile",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
// Create HTTP client with custom TLS dialer
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Make a request to a known HTTPS endpoint
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints.
func TestDialerWithProfile(t *testing.T) {
// Create two dialers with different profiles

View File

@@ -39,9 +39,15 @@ import (
// 设计说明:
// - client: Ent 客户端,用于类型安全的 ORM 操作
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
// - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照
type accountRepository struct {
client *dbent.Client // Ent ORM 客户端
sql sqlExecutor // 原生 SQL 执行接口
// schedulerCache 用于在账号状态变更时主动同步快照到缓存,
// 确保粘性会话能及时感知账号不可用状态。
// Used to proactively sync account snapshot to cache when status changes,
// ensuring sticky sessions can promptly detect unavailable accounts.
schedulerCache service.SchedulerCache
}
type tempUnschedSnapshot struct {
@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct {
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
return newAccountRepositoryWithSQL(client, sqlDB)
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache)
}
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
// 这种设计便于单元测试时注入 mock 对象。
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
return &accountRepository{client: client, sql: sqlq}
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository {
return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache}
}
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
r.syncSchedulerAccountSnapshot(ctx, account.ID)
}
return nil
}
@@ -540,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
// syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。
// 当账号被设置为错误、禁用、不可调度或临时不可调度时调用,
// 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。
//
// syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache
// when account status changes. Called when account is set to error, disabled,
// unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session
// logic can promptly detect the latest account state and avoid using unavailable accounts.
func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
if r == nil || r.schedulerCache == nil || accountID <= 0 {
return
}
account, err := r.GetByID(ctx, accountID)
if err != nil {
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
return
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
}
}
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -873,6 +905,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
@@ -992,6 +1025,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
}
if !schedulable {
r.syncSchedulerAccountSnapshot(ctx, id)
}
return nil
}
@@ -1146,6 +1182,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
}
shouldSync := false
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
shouldSync = true
}
if updates.Schedulable != nil && !*updates.Schedulable {
shouldSync = true
}
if shouldSync {
for _, id := range ids {
r.syncSchedulerAccountSnapshot(ctx, id)
}
}
}
return rows, nil
}

View File

@@ -21,11 +21,56 @@ type AccountRepoSuite struct {
repo *accountRepository
}
type schedulerCacheRecorder struct {
setAccounts []*service.Account
}
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
return nil, false, nil
}
func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
return nil
}
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
return nil, nil
}
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
s.setAccounts = append(s.setAccounts, account)
return nil
}
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
return nil
}
func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
return true, nil
}
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) {
return 0, nil
}
func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error {
return nil
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = newAccountRepositoryWithSQL(s.client, tx)
s.repo = newAccountRepositoryWithSQL(s.client, tx, nil)
}
func TestAccountRepoSuite(t *testing.T) {
@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() {
s.Require().Equal("updated", got.Name)
}
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
account.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// 每个 case 重新获取隔离资源
tx := testEntTx(s.T())
client := tx.Client()
repo := newAccountRepositoryWithSQL(client, tx)
repo := newAccountRepositoryWithSQL(client, tx, nil)
ctx := context.Background()
tt.setup(client)
@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().False(got.Schedulable)
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
}
func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() {
account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true})
account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
disabled := service.StatusDisabled
rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{
Status: &disabled,
})
s.Require().NoError(err)
s.Require().Equal(int64(2), rows)
s.Require().Len(cacheRecorder.setAccounts, 2)
ids := map[int64]struct{}{}
for _, acc := range cacheRecorder.setAccounts {
ids[acc.ID] = struct{}{}
}
s.Require().Contains(ids, account1.ID)
s.Require().Contains(ids, account2.ID)
}
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---

View File

@@ -0,0 +1,95 @@
package repository
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// AESEncryptor implements SecretEncryptor using AES-256-GCM
type AESEncryptor struct {
key []byte
}
// NewAESEncryptor creates a new AES encryptor
func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("invalid totp encryption key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
}
return &AESEncryptor{key: key}, nil
}
// Encrypt encrypts plaintext using AES-256-GCM
// Output format: base64(nonce + ciphertext + tag)
func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create gcm: %w", err)
}
// Generate a random nonce
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("generate nonce: %w", err)
}
// Encrypt the plaintext
// Seal appends the ciphertext and tag to the nonce
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
// Encode as base64
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts ciphertext using AES-256-GCM
func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
// Decode from base64
data, err := base64.StdEncoding.DecodeString(ciphertext)
if err != nil {
return "", fmt.Errorf("decode base64: %w", err)
}
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create gcm: %w", err)
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", fmt.Errorf("ciphertext too short")
}
// Extract nonce and ciphertext
nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
// Decrypt
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
if err != nil {
return "", fmt.Errorf("decrypt: %w", err)
}
return string(plaintext), nil
}

View File

@@ -0,0 +1,83 @@
package repository
import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type announcementReadRepository struct {
client *dbent.Client
}
func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementReadRepository {
return &announcementReadRepository{client: client}
}
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
client := clientFromContext(ctx, r.client)
return client.AnnouncementRead.Create().
SetAnnouncementID(announcementID).
SetUserID(userID).
SetReadAt(readAt).
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
DoNothing().
Exec(ctx)
}
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
if len(announcementIDs) == 0 {
return map[int64]time.Time{}, nil
}
rows, err := r.client.AnnouncementRead.Query().
Where(
announcementread.UserIDEQ(userID),
announcementread.AnnouncementIDIn(announcementIDs...),
).
All(ctx)
if err != nil {
return nil, err
}
out := make(map[int64]time.Time, len(rows))
for i := range rows {
out[rows[i].AnnouncementID] = rows[i].ReadAt
}
return out, nil
}
func (r *announcementReadRepository) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
if len(userIDs) == 0 {
return map[int64]time.Time{}, nil
}
rows, err := r.client.AnnouncementRead.Query().
Where(
announcementread.AnnouncementIDEQ(announcementID),
announcementread.UserIDIn(userIDs...),
).
All(ctx)
if err != nil {
return nil, err
}
out := make(map[int64]time.Time, len(rows))
for i := range rows {
out[rows[i].UserID] = rows[i].ReadAt
}
return out, nil
}
func (r *announcementReadRepository) CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error) {
count, err := r.client.AnnouncementRead.Query().
Where(announcementread.AnnouncementIDEQ(announcementID)).
Count(ctx)
if err != nil {
return 0, err
}
return int64(count), nil
}

View File

@@ -0,0 +1,194 @@
package repository
import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type announcementRepository struct {
client *dbent.Client
}
func NewAnnouncementRepository(client *dbent.Client) service.AnnouncementRepository {
return &announcementRepository{client: client}
}
func (r *announcementRepository) Create(ctx context.Context, a *service.Announcement) error {
client := clientFromContext(ctx, r.client)
builder := client.Announcement.Create().
SetTitle(a.Title).
SetContent(a.Content).
SetStatus(a.Status).
SetTargeting(a.Targeting)
if a.StartsAt != nil {
builder.SetStartsAt(*a.StartsAt)
}
if a.EndsAt != nil {
builder.SetEndsAt(*a.EndsAt)
}
if a.CreatedBy != nil {
builder.SetCreatedBy(*a.CreatedBy)
}
if a.UpdatedBy != nil {
builder.SetUpdatedBy(*a.UpdatedBy)
}
created, err := builder.Save(ctx)
if err != nil {
return err
}
applyAnnouncementEntityToService(a, created)
return nil
}
func (r *announcementRepository) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
m, err := r.client.Announcement.Query().
Where(announcement.IDEQ(id)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
}
return announcementEntityToService(m), nil
}
func (r *announcementRepository) Update(ctx context.Context, a *service.Announcement) error {
client := clientFromContext(ctx, r.client)
builder := client.Announcement.UpdateOneID(a.ID).
SetTitle(a.Title).
SetContent(a.Content).
SetStatus(a.Status).
SetTargeting(a.Targeting)
if a.StartsAt != nil {
builder.SetStartsAt(*a.StartsAt)
} else {
builder.ClearStartsAt()
}
if a.EndsAt != nil {
builder.SetEndsAt(*a.EndsAt)
} else {
builder.ClearEndsAt()
}
if a.CreatedBy != nil {
builder.SetCreatedBy(*a.CreatedBy)
} else {
builder.ClearCreatedBy()
}
if a.UpdatedBy != nil {
builder.SetUpdatedBy(*a.UpdatedBy)
} else {
builder.ClearUpdatedBy()
}
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
}
a.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *announcementRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.Announcement.Delete().Where(announcement.IDEQ(id)).Exec(ctx)
return err
}
func (r *announcementRepository) List(
ctx context.Context,
params pagination.PaginationParams,
filters service.AnnouncementListFilters,
) ([]service.Announcement, *pagination.PaginationResult, error) {
q := r.client.Announcement.Query()
if filters.Status != "" {
q = q.Where(announcement.StatusEQ(filters.Status))
}
if filters.Search != "" {
q = q.Where(
announcement.Or(
announcement.TitleContainsFold(filters.Search),
announcement.ContentContainsFold(filters.Search),
),
)
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
items, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(announcement.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
out := announcementEntitiesToService(items)
return out, paginationResultFromTotal(int64(total), params), nil
}
func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
q := r.client.Announcement.Query().
Where(
announcement.StatusEQ(service.AnnouncementStatusActive),
announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)),
announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)),
).
Order(dbent.Desc(announcement.FieldID))
items, err := q.All(ctx)
if err != nil {
return nil, err
}
return announcementEntitiesToService(items), nil
}
func applyAnnouncementEntityToService(dst *service.Announcement, src *dbent.Announcement) {
if dst == nil || src == nil {
return
}
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
if m == nil {
return nil
}
return &service.Announcement{
ID: m.ID,
Title: m.Title,
Content: m.Content,
Status: m.Status,
Targeting: m.Targeting,
StartsAt: m.StartsAt,
EndsAt: m.EndsAt,
CreatedBy: m.CreatedBy,
UpdatedBy: m.UpdatedBy,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func announcementEntitiesToService(models []*dbent.Announcement) []service.Announcement {
out := make([]service.Announcement, 0, len(models))
for i := range models {
if s := announcementEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}

View File

@@ -387,17 +387,20 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}

View File

@@ -35,7 +35,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
client := s.clientFactory(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
UUID string `json:"uuid"`
Name string `json:"name"`
RavenType *string `json:"raven_type"` // nil for personal, "team" for team organization
}
targetURL := s.baseURL + "/api/organizations"
@@ -65,7 +67,23 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return "", fmt.Errorf("no organizations found")
}
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
// 如果只有一个组织,直接使用
if len(orgs) == 1 {
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil
}
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
for _, org := range orgs {
if org.RavenType != nil && *org.RavenType == "team" {
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
org.UUID, org.Name, *org.RavenType)
return org.UUID, nil
}
}
// 如果没有 team 类型的组织,使用第一个
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil
}
@@ -182,7 +200,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
resp, err := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json, text/plain, */*").
SetHeader("Content-Type", "application/json").
SetHeader("User-Agent", "axios/1.8.4").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
@@ -205,8 +225,6 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
// Anthropic OAuth API 期望 JSON 格式的请求体
reqBody := map[string]any{
"grant_type": "refresh_token",
"refresh_token": refreshToken,
@@ -217,7 +235,9 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
resp, err := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json, text/plain, */*").
SetHeader("Content-Type", "application/json").
SetHeader("User-Agent", "axios/1.8.4").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)

View File

@@ -171,7 +171,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "")
if tt.wantErr {
require.Error(s.T(), err)

View File

@@ -14,37 +14,82 @@ import (
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
// 默认 User-Agent与用户抓包的请求一致
const defaultUsageUserAgent = "claude-code/2.1.7"
type claudeUsageService struct {
usageURL string
allowPrivateHosts bool
httpUpstream service.HTTPUpstream
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
// NewClaudeUsageFetcher 创建 Claude 用量获取服务
// httpUpstream: 可选,如果提供则支持 TLS 指纹伪装
func NewClaudeUsageFetcher(httpUpstream service.HTTPUpstream) service.ClaudeUsageFetcher {
return &claudeUsageService{
usageURL: defaultClaudeUsageURL,
httpUpstream: httpUpstream,
}
}
// FetchUsage 简单版本,不支持 TLS 指纹(向后兼容)
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
return s.FetchUsageWithOptions(ctx, &service.ClaudeUsageFetchOptions{
AccessToken: accessToken,
ProxyURL: proxyURL,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
// FetchUsageWithOptions 完整版本,支持 TLS 指纹和自定义 User-Agent
func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *service.ClaudeUsageFetchOptions) (*service.ClaudeUsageResponse, error) {
if opts == nil {
return nil, fmt.Errorf("options is nil")
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
// 设置请求头(与抓包一致,但不设置 Accept-Encoding让 Go 自动处理压缩)
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+opts.AccessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
// 设置 User-Agent优先使用缓存的 Fingerprint否则使用默认值
userAgent := defaultUsageUserAgent
if opts.Fingerprint != nil && opts.Fingerprint.UserAgent != "" {
userAgent = opts.Fingerprint.UserAgent
}
req.Header.Set("User-Agent", userAgent)
var resp *http.Response
// 如果启用 TLS 指纹且有 HTTPUpstream使用 DoWithTLS
if opts.EnableTLSFingerprint && s.httpUpstream != nil {
// accountConcurrency 传 0 使用默认连接池配置usage 请求不需要特殊的并发设置
resp, err = s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, 0, true)
if err != nil {
return nil, fmt.Errorf("request with TLS fingerprint failed: %w", err)
}
} else {
// 不启用 TLS 指纹,使用普通 HTTP 客户端
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: opts.ProxyURL,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
resp, err = client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
}
defer func() { _ = resp.Body.Close() }()

View File

@@ -9,13 +9,27 @@ import (
"github.com/redis/go-redis/v9"
)
const verifyCodeKeyPrefix = "verify_code:"
const (
verifyCodeKeyPrefix = "verify_code:"
passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:"
)
// verifyCodeKey generates the Redis key for email verification code.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
}
// passwordResetKey generates the Redis key for password reset token.
func passwordResetKey(email string) string {
return passwordResetKeyPrefix + email
}
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
func passwordResetSentAtKey(email string) string {
return passwordResetSentAtKeyPrefix + email
}
type emailCache struct {
rdb *redis.Client
}
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
key := verifyCodeKey(email)
return c.rdb.Del(ctx, key).Err()
}
// Password reset token methods
func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
key := passwordResetKey(email)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data service.PasswordResetTokenData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
key := passwordResetKey(email)
val, err := json.Marshal(data)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
key := passwordResetKey(email)
return c.rdb.Del(ctx, key).Err()
}
// Password reset email cooldown methods
func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
key := passwordResetSentAtKey(email)
exists, err := c.rdb.Exists(ctx, key).Result()
return err == nil && exists > 0
}
func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
key := passwordResetSentAtKey(email)
return c.rdb.Set(ctx, key, "1", ttl).Err()
}

View File

@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err()
}
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
// 以便下次请求能够重新选择可用账号。
//
// DeleteSessionAccountID removes the sticky session binding for the given session.
// Called when the bound account becomes unavailable (e.g., error status, disabled,
// or unschedulable), allowing subsequent requests to select a new available account.
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
sessionID := "openai:s4"
accountID := int64(102)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
groupID := int64(1)

View File

@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
}
func TestGatewayRoutingSuite(t *testing.T) {

View File

@@ -2,10 +2,11 @@ package repository
import (
"context"
"fmt"
"net/http"
"net/url"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
@@ -38,16 +39,17 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
resp, err := client.R().
SetContext(ctx).
SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
@@ -66,16 +68,17 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
resp, err := client.R().
SetContext(ctx).
SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
@@ -84,6 +87,6 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
func createOpenAIReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
Timeout: 120 * time.Second,
})
}

View File

@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
require.ErrorContains(s.T(), err, "status 401")
}
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
client := NewOpenAIOAuthClient()
svc, ok := client.(*openaiOAuthService)
require.True(t, ok)
require.Equal(t, openai.TokenURL, svc.tokenURL)
}
func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite))
}

Some files were not shown because too many files have changed in this diff Show More