Compare commits

..

168 Commits

Author SHA1 Message Date
shaw
8f39754812 fix(gateway): 修复模型前缀映射逻辑错误
问题:normalizeClaudeModelForAnthropic 函数错误地将长模型ID截断为短ID,
导致 APIKey 账号的模型名被错误修改。

修复:
- 删除错误的 normalizeClaudeModelForAnthropic 函数和 anthropicPrefixMappings 变量
- 直接使用 claude.NormalizeModelID(正确的短ID->长ID扩展)
- APIKey 账号无显式映射时透传原始模型名
2026-02-04 17:50:05 +08:00
Wesley Liddick
ac4371fa98 Merge pull request #482 from IanShaw027/fix/gateway-compatibility
fix(gemini): 优化 Gemini 接口认证兼容性,支持 Authorization: Bearer
2026-02-04 16:50:52 +08:00
柴叁
9985c4a344 fix(gemini): 优化 Gemini 接口认证兼容性,支持 Authorization: Bearer
调整 API key 提取优先级,让 /v1beta 接口同时支持 x-goog-api-key 和
Authorization: Bearer 两种认证方式,解决 OpenClaw 等使用 Bearer 认证
的客户端无法直接访问 Gemini 接口的问题。
2026-02-04 16:26:36 +08:00
Wesley Liddick
804b6f2282 Merge pull request #468 from s-Joshua-s/fix/thinking-block-modification-error
fix(api): 修复 thinking 块被意外修改导致的 400 错误
2026-02-03 22:21:06 +08:00
Wesley Liddick
cb58daf38d Merge pull request #476 from touwaeriol/fix/gemini-thoughtsignature-v2
fix(gemini): 导出 DummyThoughtSignature 常量并修复跨账号签名验证
2026-02-03 21:56:34 +08:00
Wesley Liddick
a9398d210b Merge pull request #475 from touwaeriol/fix/openai-oauth-instructions-v2
fix(openai): 统一 OAuth instructions 处理逻辑,修复 Codex CLI 400 错误
2026-02-03 21:56:13 +08:00
shaw
df1c2383da chore: fix gofmt formatting 2026-02-03 21:52:49 +08:00
Wesley Liddick
ff8b1b4ae3 Merge pull request #467 from slovx2/main
Antigravity 相关BUG修复及调度优化
2026-02-03 21:51:04 +08:00
Wesley Liddick
4cce21b125 Merge branch 'main' into main 2026-02-03 21:43:41 +08:00
liuxiongfeng
6baf810885 fix(gemini): 导出 DummyThoughtSignature 常量并修复跨账号签名验证
- 将 dummyThoughtSignature 改为导出的 DummyThoughtSignature 常量,供跨包使用
- 修改 gemini_native_signature_cleaner.go 将删除签名改为替换为 dummy 签名
  这样可以跳过 Gemini 3 的签名验证,解决粘性会话切换账号时的验证失败问题
- 更新相关测试文件

Fixes: 粘性会话切换账号时 thoughtSignature 验证失败导致 400 错误
2026-02-03 21:34:55 +08:00
liuxiongfeng
9a48b2e942 fix(openai): 统一 OAuth instructions 处理逻辑,修复 Codex CLI 400 错误
- 修改 applyCodexOAuthTransform 函数签名,增加 isCodexCLI 参数
- 移除 && !isCodexCLI 条件,对所有 OAuth 请求统一处理
- 新增 applyInstructions/applyCodexCLIInstructions/applyOpenCodeInstructions 辅助函数
- 新增 isInstructionsEmpty 函数检查 instructions 字段是否为空
- 添加 Codex CLI 和非 Codex CLI 场景的测试用例

逻辑说明:
- Codex CLI + 有 instructions: 保持不变
- Codex CLI + 无 instructions: 补充 opencode 指令
- 非 Codex CLI: 使用 opencode 指令覆盖
2026-02-03 21:22:33 +08:00
Wesley Liddick
c0c9c984d1 Merge pull request #471 from bayma888/feature/api-key-quota-expiration
feat(api-key): 添加API密钥独立配额和过期时间功能
2026-02-03 21:11:17 +08:00
bayma888
3fed478e4d fix(lint): format gateway_service.go with gofmt 2026-02-03 20:55:17 +08:00
bayma888
0afc5d0b1a fix(test): update API contract tests for new quota fields
Add quota, quota_used, expires_at fields to expected JSON responses
in POST /api/v1/keys and GET /api/v1/keys test cases.
2026-02-03 20:54:25 +08:00
Wesley Liddick
ba5a0d47eb Merge pull request #460 from s-Joshua-s/fix/proxy-probe-fallback
fix(proxy): 增加代理探测的多 URL 回退机制
2026-02-03 20:49:59 +08:00
bayma888
be7bc658fc fix(test): add IncrementQuotaUsed to all APIKeyRepository test stubs
- Add missing IncrementQuotaUsed method to stubApiKeyRepo in api_contract_test.go
- Fix gofmt formatting issues in api_key_service.go, dto/types.go, api_key_handler.go
2026-02-03 20:49:58 +08:00
Wesley Liddick
c89bbf5130 Merge pull request #458 from bayma888/feature/admin-user-balance-history
feat(admin): 管理员可查看每个用户充值和并发变动记录、点击余额可直接查看、优化弹框UI
2026-02-03 20:37:30 +08:00
bayma888
e59e3a9f00 fix(test): add IncrementQuotaUsed to all APIKeyRepository test stubs
Add the missing IncrementQuotaUsed method to:
- fakeAPIKeyRepo (api_key_auth_google_test.go)
- stubApiKeyRepo (api_key_auth_test.go)
- apiKeyRepoStub (api_key_service_delete_test.go)
- authRepoStub (api_key_service_cache_test.go)
2026-02-03 20:00:43 +08:00
bayma888
6146be1474 feat(api-key): add independent quota and expiration support
This feature allows API Keys to have their own quota limits and expiration
times, independent of the user's balance.

Backend:
- Add quota, quota_used, expires_at fields to api_key schema
- Implement IsExpired() and IsQuotaExhausted() checks in middleware
- Add ResetQuota and ClearExpiration API endpoints
- Integrate quota billing in gateway handlers (OpenAI, Anthropic, Gemini)
- Include quota/expiration fields in auth cache for performance
- Expiration check returns 403, quota exhausted returns 429

Frontend:
- Add quota and expiration inputs to key create/edit dialog
- Add quick-select buttons for expiration (+7, +30, +90 days)
- Add reset quota confirmation dialog
- Add expires_at column to keys list
- Add i18n translations for new features (en/zh)

Migration:
- Add 045_add_api_key_quota.sql for new columns
2026-02-03 19:49:31 +08:00
bayma888
730d2a9ad2 fix(test): add missing stub methods to stubRedeemCodeRepo in api_contract_test
Add ListByUserPaginated and SumPositiveBalanceByUser methods
2026-02-03 19:36:17 +08:00
bayma888
d008941cb3 fix(test): add missing stub methods for RedeemCodeRepository and AdminService
Add ListByUserPaginated and SumPositiveBalanceByUser to redeemRepoStub
Add GetUserBalanceHistory to stubAdminService

Fixes CI test compilation errors
2026-02-03 19:29:39 +08:00
JIA-ss
df7a3e65ee fix(proxy): 增加代理探测的多 URL 回退机制
某些 AI API 专用代理只允许访问特定域名(如 anthropic.com、openai.com),
会拦截对 ip-api.com 的请求。本次修改增加了多 URL 回退机制:

1. 优先使用 ip-api.com(可获取详细地理信息:城市、地区、国家)
2. 若失败则回退到 httpbin.org/ip(仅获取 IP 地址,速度快)

主要变更:
- 新增 probeURLs 列表支持多个探测 URL
- 重构 ProbeProxy() 实现回退逻辑
- 新增 parseHTTPBin() 解析 httpbin.org 响应
- 优化错误信息,JSON 解析失败时显示响应体前 200 字符

Co-Authored-By: Claude <noreply@anthropic.com>
2026-02-03 17:12:27 +08:00
song
0707f3d963 chore: gofmt 2026-02-03 17:03:54 +08:00
song
976d6fb03f chore: update antigravity example url 2026-02-03 17:00:46 +08:00
song
f1aafbc06f chore: gofmt 2026-02-03 16:55:13 +08:00
song
7cb5444dbb fix: update tests for group fallback 2026-02-03 16:48:52 +08:00
song
3bede6e65f merge upstream main 2026-02-03 16:21:58 +08:00
JIA-ss
ad90bb4645 fix(api): 修复 thinking 块被意外修改导致的 400 错误
问题描述:
使用扩展思考功能时,偶现以下错误:
"thinking or redacted_thinking blocks in the latest assistant message cannot be modified"

根因分析:
当代理服务修改请求体中的某些字段时(如 metadata.user_id、model),
使用 map[string]any 解析整个 JSON 后重新序列化,导致:
1. 字段顺序改变(Go map 序列化按字母排序)
2. 数字格式变化(如 1.0 → 1)
3. Unicode 转义变化

Claude API 对 thinking 块进行字节级验证,任何变化都会触发错误。

修复内容:
1. identity_service.go - RewriteUserID/RewriteUserIDWithMasking
   使用 json.RawMessage 保留其他字段的原始字节

2. gateway_service.go - replaceModelInBody
   使用 json.RawMessage 保留其他字段的原始字节

3. gateway_service.go - normalizeClaudeOAuthRequestBody
   保留 messages 的原始字节,跳过包含 thinking 块的消息修改

4. gateway_service.go - isThinkingBlockSignatureError
   添加 "cannot be modified" 错误检测,触发自动重试

5. antigravity_gateway_service.go - isSignatureRelatedError
   添加 "cannot be modified" 错误检测

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 16:15:37 +08:00
song
2220fd18ca merge upstream main 2026-02-03 15:36:17 +08:00
Wesley Liddick
bb3df5785a Merge pull request #322 from xlx0852/main
fix: 混合渠道警告确认框和过滤 prompt_cache_retention 参数
2026-02-03 15:13:10 +08:00
Wesley Liddick
6e54eda41f Merge pull request #464 from touwaeriol/pr/antigravity-scope-ratelimit
feat(antigravity): 支持按配额域(scope)级别限流
2026-02-03 15:02:15 +08:00
Wesley Liddick
df4c0adf0b Merge pull request #463 from DuckyProject/feat/usage-records-codex-reasoning-effort
feat(usage): add reasoning effort column
2026-02-03 14:57:45 +08:00
Wesley Liddick
7cbe4afdb8 Merge pull request #462 from touwaeriol/pr/antigravity-gemini-mapping
feat(antigravity): map all gemini-2.5 to gemini-3 series
2026-02-03 14:55:45 +08:00
Wesley Liddick
16f150caae Merge pull request #459 from IanShaw027/fix/gemini-error
fix(gemini): 为 Gemini 工具调用添加 thoughtSignature 避免 INVALID_ARGUMENT 错误
2026-02-03 14:48:50 +08:00
Wesley Liddick
7229b41fc7 Merge pull request #420 from shuike/feat-invitation-code
feat: 增加邀请码注册功能
2026-02-03 14:44:15 +08:00
ianshaw
2cd5037878 chore: 触发 CI 重新运行 2026-02-03 14:43:59 +08:00
ducky
53ee6383db feat(usage): add reasoning effort column 2026-02-03 14:36:29 +08:00
Wesley Liddick
a09478f374 Merge pull request #316 from cyhhao/fix/claude-oauth-compat
fix(网关): 完善 Claude OAuth/Claude Code 兼容
2026-02-03 14:26:19 +08:00
liuxiongfeng
d3c1d77a35 fix(frontend): import missing formatTime function in AccountStatusIndicator 2026-02-03 14:25:30 +08:00
liuxiongfeng
8824400c3e feat(accounts): 账号列表显示 Antigravity scope 级别限流状态
- 后端 DTO 新增 scope_rate_limits 字段,从 extra 提取限流信息
- 前端状态列显示 scope 级限流徽章(Claude/Gemini/Image)
- 清除速率限制时同时清除账号级和 scope 级限流(已有实现)

Cherry-picked from slovx2/sub2api: 66f49b67
2026-02-03 14:25:30 +08:00
liuxiongfeng
6e8eff9bb9 feat(ops): 运维界面展示 Antigravity 账号 scope 级别限流统计
在运维监控的并发/排队卡片中,为 Antigravity 平台账号显示各 scope
(claude/gemini_text/gemini_image) 的限流数量统计,便于管理员了解
哪些 scope 正在被限流。

Cherry-picked from slovx2/sub2api: 08d6dc52
2026-02-03 14:25:30 +08:00
liuxiongfeng
f5884d1608 fix: jsonb_set 嵌套路径无法创建多层 key 的问题
PostgreSQL jsonb_set 在 create_if_missing=true 时无法一次性创建多层嵌套路径。
例如设置 {antigravity_quota_scopes,gemini_image} 时,如果 antigravity_quota_scopes 不存在,
jsonb_set 不会自动创建外层 key,导致更新静默失败(affected=1 但数据未变)。

修复方案:嵌套两次 jsonb_set,先确保外层 key 存在,再设置内层值。
同时在设置限流时更新 last_used_at,使刚触发 429 的账号调度优先级降低。

Cherry-picked from slovx2/sub2api: 4b57e80e
2026-02-03 14:25:30 +08:00
liuxiongfeng
56949a58bc feat(antigravity): 默认开启按配额域限流,避免整个账号被锁定
将 GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT 的默认值从关闭改为开启。
当 Gemini 模型触发 429 限流时,只会限制对应的配额域(gemini_text),
而 Claude 和 gemini_image 仍可继续使用,提高账号利用率。
2026-02-03 14:25:30 +08:00
liuxiongfeng
7d256879c5 feat(antigravity): map all gemini-2.5 to gemini-3 series
Antigravity 上游不再支持 gemini-2.5 系列,统一映射到 gemini-3:
- gemini-2.5-flash → gemini-3-flash
- gemini-2.5-flash-lite → gemini-3-flash
- gemini-2.5-flash-thinking → gemini-3-flash
- gemini-2.5-flash-image → gemini-3-pro-image
- gemini-2.5-pro → gemini-3-pro-high
- gemini-2.5-pro-preview → gemini-3-pro-high
- gemini-2.5-pro-exp → gemini-3-pro-high
2026-02-03 14:23:47 +08:00
liuxiongfeng
f9512fda58 test: update gemini-2.5-pro mapping test case
Update test expectation to reflect new mapping:
gemini-2.5-pro -> gemini-3-pro-high
2026-02-03 14:23:47 +08:00
liuxiongfeng
beb63cb152 feat(antigravity): map gemini-2.5-pro to gemini-3-pro-high
Add prefix mapping rules for gemini-2.5-pro variants:
- gemini-2.5-pro -> gemini-3-pro-high
- gemini-2.5-pro-preview -> gemini-3-pro-high
- gemini-2.5-pro-exp -> gemini-3-pro-high
2026-02-03 14:23:47 +08:00
song
11ff73b578 fix: align migration checksum and import formatTime 2026-02-03 14:04:11 +08:00
shuike
0ed4a404e4 fix(test): api_contract_test添加 invitation_code_enabled 字段 2026-02-03 13:38:44 +08:00
shuike
6c86501d11 feat: 增加邀请码注册功能 2026-02-03 13:38:44 +08:00
Call White
2fe8932c1d Merge pull request #3 from cyhhao/main
merge to main
2026-02-03 11:44:54 +08:00
shaw
0ab68aa9fb fix(setup): 修复 Redis TLS 配置选项位置错误
- 将 TLS Toggle 从第一步(PostgreSQL)移动到第二步(Redis)
- 添加缺失的 Toggle 组件导入

问题描述:
TLS 配置选项错误地出现在数据库配置步骤中,而不是 Redis 配置步骤
2026-02-03 11:24:04 +08:00
Wesley Liddick
2f92b06869 Merge pull request #457 from touwaeriol/pr/group-copy-accounts
feat(groups): 添加从其他分组复制账号功能
2026-02-03 08:45:13 +08:00
ianshaw
03e94f9f53 fix(gemini): 为 Gemini 工具调用添加 thoughtSignature 避免 INVALID_ARGUMENT 错误 2026-02-03 06:01:29 +08:00
bayma888
606e29d390 feat(admin): add user balance/concurrency history modal
- Add new API endpoint GET /admin/users/:id/balance-history with pagination and type filter
- Add SumPositiveBalanceByUser for calculating total recharged amount
- Create UserBalanceHistoryModal component with:
  - User info header (email, username, created_at, current balance, notes, total recharged)
  - Type filter dropdown (all/balance/admin_balance/concurrency/admin_concurrency/subscription)
  - Quick deposit/withdraw buttons
  - Paginated history list with icons and colored values
- Add instant tooltip on balance column for better UX
- Add z-index prop to BaseDialog for modal stacking control
- Update i18n translations (zh/en)
2026-02-03 00:16:10 +08:00
song
3ecadf4aad chore: apply stashed changes 2026-02-02 22:20:08 +08:00
song
0170d19fa7 merge upstream main 2026-02-02 22:13:50 +08:00
liuxiongfeng
ce1d2904c7 test: 为测试 stub 添加缺失的 GroupRepository 接口方法
新增 BindAccountsToGroup 和 GetAccountIDsByGroupIDs 方法的 stub 实现,
确保测试文件中的 mock 类型满足 GroupRepository 接口要求。
2026-02-02 22:06:37 +08:00
Wesley Liddick
ea41f830fd Merge pull request #456 from touwaeriol/pr/gemini-long-context-billing
feat(billing): 添加 Gemini 200K 长上下文双倍计费功能
2026-02-02 21:57:25 +08:00
liuxiongfeng
e1a4a7b8c0 feat(groups): 添加从其他分组复制账号功能
- 创建分组时可选择从已有分组复制账号
- 编辑分组时支持同步账号(全量替换操作)
- 仅允许选择相同平台的源分组
- 添加完整的数据校验:去重、自引用检查、平台一致性检查
- 前端支持多选源分组,带提示说明操作行为
2026-02-02 21:47:47 +08:00
liuxiongfeng
b381e8ee73 refactor(billing): 简化 CalculateCostWithLongContext 逻辑
将 token 直接拆分为范围内和范围外两部分,分别调用 CalculateCost:
- 范围内:正常计费 (rateMultiplier)
- 范围外:双倍计费 (rateMultiplier × extraMultiplier)

代码更直观,便于理解和维护
2026-02-02 21:47:02 +08:00
liuxiongfeng
45e1429ae8 feat(billing): 添加 Gemini 200K 长上下文双倍计费功能
- 新增 CalculateCostWithLongContext 方法支持阈值双倍计费
- 新增 RecordUsageWithLongContext 方法专用于 Gemini 计费
- Gemini 超过 200K token 的部分按 2 倍费率计算
- 其他平台(Claude/OpenAI)完全不受影响
2026-02-02 21:47:02 +08:00
shaw
7b1d63a786 fix(types): 添加缺失的 ignore_invalid_api_key_errors 类型定义
OpsAdvancedSettings 接口缺少 ignore_invalid_api_key_errors 字段,
导致 TypeScript 编译报错。
2026-02-02 21:01:32 +08:00
Wesley Liddick
e204b4d81f Merge pull request #452 from s-Joshua-s/feat/enhance-usage-endpoint
feat(gateway): 增强 /v1/usage 端点返回完整用量统计
2026-02-02 20:35:00 +08:00
Wesley Liddick
325ed747d8 Merge pull request #455 from ZeroClover/feat/ops-ignore-invalid-api-key-errors
feat(ops): 支持过滤无效 API Key 错误,不写入错误日志
2026-02-02 20:28:00 +08:00
Wesley Liddick
cbf3dba28d Merge pull request #454 from ZeroClover/feat-exclude-user-inactive-errors
feat(ops): 将 USER_INACTIVE 错误排除在 SLA 统计之外
2026-02-02 20:19:48 +08:00
Wesley Liddick
4329f72abf Merge pull request #450 from bayma888/feature/show-admin-adjustment-notes
feat: 向用户显示管理员调整余额的备注
2026-02-02 20:19:23 +08:00
Zero Clover
ad1cdba338 feat(ops): 支持过滤无效 API Key 错误,不写入错误日志
新增 IgnoreInvalidApiKeyErrors 开关,启用后 INVALID_API_KEY 和
API_KEY_REQUIRED 错误将被完全跳过,不写入 Ops 错误日志。
这些错误由用户错误配置导致,与服务质量无关。
2026-02-02 20:16:17 +08:00
Wesley Liddick
016c3915d7 Merge pull request #449 from bayma888/feature/user-search-support-notes
feat: 支持在用户列表搜索中使用备注字段和模糊查询、支持用户名备注等搜索
2026-02-02 20:16:03 +08:00
shaw
79fa18132b fix(gateway): 修复 OAuth token 刷新后调度器缓存不一致问题
Token 刷新成功后,调度器缓存中的 Account 对象仍包含旧的 credentials,
导致在 Outbox 异步更新之前(最多 1 秒窗口)请求使用过期 token,
返回 403 错误(OAuth token has been revoked)。

修复方案:在 token 刷新成功后同步更新调度器缓存,确保调度获取的
Account 对象立即包含最新的 access_token 和 _token_version。

此修复覆盖所有 OAuth 平台:OpenAI、Claude、Gemini、Antigravity。
2026-02-02 20:05:37 +08:00
Zero Clover
673caf41a0 feat(ops): 将 USER_INACTIVE 错误排除在 SLA 统计之外
将账户停用 (USER_INACTIVE) 导致的请求失败视为业务限制类错误,不计入 SLA 和错误率统计。

账户停用是预期内的业务结果,不应被视为系统错误或服务质量问题。此改动使错误分类更加准确,避免将预期的业务限制误报为系统故障。

修改内容:
- 在 classifyOpsIsBusinessLimited 函数中添加 USER_INACTIVE 错误码
- 该类错误不再触发错误率告警

Fixes Wei-Shaw/sub2api#453
2026-02-02 18:50:54 +08:00
JIA-ss
c441638fc0 feat(gateway): 增强 /v1/usage 端点返回完整用量统计
为 CC Switch 集成增强 /v1/usage 网关端点,在保持原有 4 字段
(isValid, planName, remaining, unit) 向后兼容的基础上,新增:

- usage 对象:今日/累计的请求数、token 用量、费用,以及 RPM/TPM
- subscription 对象(订阅模式):日/周/月用量和限额、过期时间
- balance 字段(余额模式):当前钱包余额

用量数据获取采用 best-effort 策略,失败不影响基础响应。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 18:30:06 +08:00
小北
ae18397ca6 feat: 向用户显示管理员调整余额的备注
- 为RedeemCode DTO添加notes字段(仅用于admin_balance/admin_concurrency类型)
- 更新mapper使其有条件地包含备注信息
- 在用户兑换历史UI中显示备注
- 备注以斜体显示,悬停时显示完整内容

用户现在可以看到管理员调整其余额的原因说明。

Changes:
- backend/internal/handler/dto/types.go: RedeemCode添加notes字段
- backend/internal/handler/dto/mappers.go: 条件性填充notes
- frontend/src/api/redeem.ts: TypeScript接口添加notes
- frontend/src/views/user/RedeemView.vue: UI显示备注信息
2026-02-02 17:44:50 +08:00
小北
426ce616c0 feat: 支持在用户搜索中使用备注字段
- 在用户仓库的搜索过滤器中添加备注字段
- 管理员现在可以通过备注/标记搜索用户
- 使用不区分大小写的搜索(ContainsFold)

Changes:
- backend/internal/repository/user_repo.go: 添加 NotesContainsFold 到搜索条件
2026-02-02 17:41:27 +08:00
shaw
5cda979209 feat(deploy): 优化 Docker 部署体验,新增一键部署脚本
## 新增功能

- 新增 docker-compose.local.yml:使用本地目录存储数据,便于迁移和备份
- 新增 docker-deploy.sh:一键部署脚本,自动生成安全密钥(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
- 新增 deploy/.gitignore:忽略运行时数据目录

## 优化改进

- docker-compose.local.yml 包含 PGDATA 环境变量修复,解决 PostgreSQL 18 Alpine 数据丢失问题
- 脚本自动设置 .env 文件权限为 600,增强安全性
- 脚本显示生成的凭证,方便用户记录

## 文档更新

- 更新 README.md(英文版):新增"快速开始"章节,添加部署版本对比表
- 更新 README_CN.md(中文版):同步英文版更新
- 更新 deploy/README.md:详细说明两种部署方式和迁移方法

## 使用方式

一键部署:
```bash
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
docker-compose -f docker-compose.local.yml up -d
```

轻松迁移:
```bash
tar czf sub2api-complete.tar.gz deploy/
# 传输到新服务器后直接解压启动即可
```
2026-02-02 16:17:07 +08:00
Wesley Liddick
cc7e67b01a Merge pull request #445 from touwaeriol/fix/gemini-cache-token-billing
fix(billing): 修复 Gemini 接口缓存 token 统计
2026-02-02 15:22:46 +08:00
Wesley Liddick
6999a9c011 Merge pull request #444 from touwaeriol/fix/gemini-apikey-passthrough
feat(gateway): Gemini API Key 账户跳过模型映射检查,直接透传
2026-02-02 15:17:05 +08:00
shaw
bbdc8663d3 feat: 重新设计公告系统为Header铃铛通知
- 新增 AnnouncementBell 组件,支持 Modal 弹窗和 Markdown 渲染
- 移除 Dashboard 横幅和独立公告页面
- 铃铛位置在 Header 文档按钮左侧,显示未读红点
- 支持点击查看详情、标记已读、全部已读等操作
- 完善国际化,移除所有硬编码中文
- 修复 AnnouncementTargetingEditor watch 循环问题
2026-02-02 15:15:39 +08:00
liuxiongfeng
4bfeeecb05 fix(billing): 修复 Gemini 接口缓存 token 统计
extractGeminiUsage 函数未提取 cachedContentTokenCount,
导致计费时缓存读取 token 始终为 0。

修复:
- 提取 usageMetadata.cachedContentTokenCount
- 设置 CacheReadInputTokens 字段
- InputTokens 减去缓存 token(与 response_transformer 逻辑一致)
2026-02-02 14:01:17 +08:00
liuxiongfeng
bbc7b4aeed feat(gateway): Gemini API Key 账户跳过模型映射检查,直接透传
Gemini API Key 账户通常代理上游服务,模型支持由上游判断,
本地不需要预先配置模型映射。
2026-02-02 13:40:29 +08:00
Wesley Liddick
d3062b2e46 Merge pull request #434 from DuckyProject/feat/announcement-system-pr-upstream
feat(announcements): admin/user announcement system
2026-02-02 10:50:26 +08:00
Wesley Liddick
b7777fb46c Merge pull request #436 from iBenzene/feat/redis-tls-support
feat: add support for using TLS to connect to Redis
2026-02-02 10:02:25 +08:00
iBenzene
35f39ca291 chore: 修复了 redis.go 中代码风格(golangci-lint)的问题 2026-01-31 19:06:19 +08:00
iBenzene
f2e206700c feat: add support for using TLS to connect to Redis 2026-01-31 03:58:01 +08:00
cyhhao
adb77af1d9 fix: satisfy golangci-lint (nil checks, remove unused helpers) 2026-01-31 02:07:57 +08:00
cyhhao
3a34746668 refactor: stop rewriting tool descriptions; keep only system sentence rewrite 2026-01-31 02:01:51 +08:00
cyhhao
fe17058700 refactor: limit OpenCode keyword replacement to tool descriptions 2026-01-31 01:40:38 +08:00
ducky
9bee0a2071 chore: gofmt for golangci-lint 2026-01-30 17:28:53 +08:00
ducky
b7f69844e1 feat(announcements): add admin/user announcement system
Implements announcements end-to-end (admin CRUD + read status, user list + mark read) with OR-of-AND targeting. Also breaks the ent<->service import cycle by moving schema-facing constants/targeting into a new domain package.
2026-01-30 16:45:04 +08:00
cyhhao
602bf9c017 Merge branch 'main' of github.com:Wei-Shaw/sub2api 2026-01-30 13:21:25 +08:00
Wesley Liddick
c3d1891ccd Merge pull request #427 from touwaeriol/pr/upgrade-antigravity-ua
chore: upgrade Antigravity User-Agent to 1.15.8
2026-01-30 09:17:17 +08:00
shaw
4d8f2db924 fix: 更新所有CI workflow的Go版本验证至1.25.6 2026-01-30 08:57:37 +08:00
shaw
6599b366dc fix: 升级Go版本至1.25.6修复标准库安全漏洞
修复GO-2026-4341和GO-2026-4340两个标准库漏洞
2026-01-30 08:53:53 +08:00
liuxiongfeng
ba16ace697 chore: upgrade Antigravity User-Agent to 1.15.8 2026-01-30 08:14:52 +08:00
song
7ade9baa15 fix(gateway): 过滤 Gemini 请求中 parts 为空的消息
Gemini API 不接受 contents 数组中 parts 为空的消息,会返回 400 INVALID_ARGUMENT 错误。
添加 filterEmptyPartsFromGeminiRequest 函数在转发前过滤这类消息。

影响范围:ForwardGemini (antigravity) 和 ForwardNative (gemini)
2026-01-29 21:09:33 +08:00
cyhhao
fa454b1b99 fix: align Claude Code system banner with opencode latest 2026-01-29 15:37:07 +08:00
cyhhao
8375094c69 fix(oauth): match Claude CLI accept header and beta set 2026-01-29 15:31:29 +08:00
cyhhao
91079d3f15 chore(debug): emit Claude mimic fingerprint on credential-scope error 2026-01-29 15:17:46 +08:00
cyhhao
63412a9fcc chore(debug): log Claude mimic fingerprint 2026-01-29 03:13:14 +08:00
cyhhao
d98648f03b fix: rewrite OpenCode identity sentence to Claude Code 2026-01-29 03:03:40 +08:00
cyhhao
c37fe91672 fix(oauth): update Claude CLI fingerprint headers 2026-01-29 02:52:26 +08:00
cyhhao
4d40fb6b60 fix(oauth): merge anthropic-beta and force Claude Code headers in mimic mode 2026-01-29 02:36:28 +08:00
cyhhao
be3b788b8f fix: also prefix next system block with Claude Code banner 2026-01-29 02:03:54 +08:00
cyhhao
723e54013a fix(oauth): mimic Claude Code metadata and beta headers 2026-01-29 01:49:51 +08:00
cyhhao
4d566f68b6 chore: gofmt 2026-01-29 01:34:58 +08:00
cyhhao
31f817d189 fix: add newline separation for Claude Code system prompt 2026-01-29 01:28:43 +08:00
cyhhao
59231668c5 Merge branch 'main' of github.com:Wei-Shaw/sub2api 2026-01-29 01:16:36 +08:00
shaw
cadca752c4 修复SSE流式响应中usage数据被覆盖的问题 2026-01-28 18:36:21 +08:00
Wesley Liddick
edf215e6fd Merge pull request #409 from DuckyProject/feat/purchase-subscription-iframe
feat(purchase): 增加购买订阅 iframe 页面与配置
2026-01-28 17:28:47 +08:00
shaw
e12dd079fd 修复调度器空缓存导致的竞态条件bug
当新分组创建后立即绑定账号时,调度器会错误地将空快照视为有效缓存命中,
导致返回没有可调度的账号。现在空快照会触发数据库回退查询。
2026-01-28 17:26:32 +08:00
ducky
04a509d45e feat(purchase): 增加购买订阅 iframe 页面与配置
- 新增 /purchase 页面(iframe + 新窗口兜底)

- 管理员系统设置可配置开关与URL

- 非 simple mode 才在侧边栏展示入口
2026-01-28 13:54:32 +08:00
Wesley Liddick
269a659200 Merge pull request #406 from geminiwen/main
fix(openai-oauth): 改进错误处理和代理支持
2026-01-28 13:53:44 +08:00
Wesley Liddick
2c31bf46b5 Merge pull request #401 from slovx2/heihuzi_main
feat(gemini): 为 Gemini 原生平台添加图片计费支持
2026-01-28 13:51:14 +08:00
song
5b787334c8 antigravity: 转发优先 daily 2026-01-28 11:17:39 +08:00
song
f761afb1ef antigravity: 区分切换后重试次数 2026-01-28 00:01:03 +08:00
Gemini Wen
8f6639f825 fix(response): add nil check for c.Request in error logging
Prevents panic when ErrorFrom is called in test contexts where
gin.CreateTestContext doesn't set up an HTTP request.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 19:26:44 +08:00
Gemini Wen
fc17d9d7df chore: bump version to 0.1.61 and fix tests
- Update VERSION from 0.1.46 to 0.1.61
- Remove ForceHTTP2 tests for OpenAI OAuth client (ForceHTTP2 was removed)
- Update createOpenAIReqClient test to use new single-arg signature

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 19:22:45 +08:00
Gemini Wen
ab092e88a8 fix(openai-oauth): 改进错误处理和代理支持
- 使用 ApplicationError 返回详细错误信息到前端
- 添加 User-Agent: codex-cli/0.91.0
- 移除 ForceHTTP2 以兼容 HTTP 代理
- 修复代理获取失败时静默忽略的问题
- 500 错误时记录完整错误日志

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 19:13:01 +08:00
song
877c17251d feat(group): 添加 MCP XML 注入开关
- Group 新增 mcp_xml_inject 字段,控制 Antigravity 平台的 MCP XML 协议注入
- 默认启用,可在分组设置中关闭
- 修复 GetByKeyForAuth 遗漏查询 mcp_xml_inject 字段导致认证缓存值始终为 false 的问题
2026-01-27 13:09:56 +08:00
cyhhao
ffe43f6098 Merge branch 'main' of github.com:Wei-Shaw/sub2api 2026-01-27 11:09:11 +08:00
song
66f49b67d6 feat(accounts): 账号列表显示 Antigravity scope 级别限流状态
- 后端 DTO 新增 scope_rate_limits 字段,从 extra 提取限流信息
- 前端状态列显示 scope 级限流徽章(Claude/Gemini/Image)
- 清除速率限制时同时清除账号级和 scope 级限流(已有实现)
2026-01-27 11:04:41 +08:00
song
08d6dc5227 feat(ops): 运维界面展示 Antigravity 账号 scope 级别限流统计
在运维监控的并发/排队卡片中,为 Antigravity 平台账号显示各 scope
(claude/gemini_text/gemini_image) 的限流数量统计,便于管理员了解
哪些 scope 正在被限流。
2026-01-27 09:34:10 +08:00
song
7cea6b6fc9 feat(gemini): 为 Gemini 原生平台添加图片计费支持
对齐 Antigravity 平台的图片计费逻辑:
- 添加 extractImageSize() 方法提取图片尺寸
- Forward() 和 ForwardNative() 返回 ImageCount/ImageSize
- 支持分组自定义图片价格和倍率
2026-01-27 00:33:48 +08:00
song
4b57e80e6a fix: jsonb_set 嵌套路径无法创建多层 key 的问题
PostgreSQL jsonb_set 在 create_if_missing=true 时无法一次性创建多层嵌套路径。
例如设置 {antigravity_quota_scopes,gemini_image} 时,如果 antigravity_quota_scopes 不存在,
jsonb_set 不会自动创建外层 key,导致更新静默失败(affected=1 但数据未变)。

修复方案:嵌套两次 jsonb_set,先确保外层 key 存在,再设置内层值。

影响范围:
- SetAntigravityQuotaScopeLimit: Antigravity 平台按模型 scope 限流
- SetModelRateLimit: Anthropic 平台 Sonnet 模型限流
2026-01-26 23:40:48 +08:00
song
0059a232a6 feat(gemini): 为 Gemini 原生平台添加图片计费支持
对齐 Antigravity 平台的图片计费逻辑:
- 添加 extractImageSize() 方法提取图片尺寸
- Forward() 和 ForwardNative() 返回 ImageCount/ImageSize
- 支持分组自定义图片价格和倍率
2026-01-26 20:51:40 +08:00
cyhhao
a161fcc89b Merge branch 'main' of github.com:Wei-Shaw/sub2api 2026-01-26 10:44:38 +08:00
song
e316a923d4 fix(ops): count failover kinds with suffix 2026-01-24 01:14:44 +08:00
song
fd0370c07a Add invalid-request fallback routing 2026-01-23 22:24:46 +08:00
song
316f2fee21 feat(ops): add account switch metrics and trend 2026-01-23 19:39:48 +08:00
song
3002c7a17f Clamp Claude maxOutputTokens to 64000 2026-01-23 10:44:21 +08:00
song
207e09500a feat(antigravity): 支持按模型类型配置重试次数
新增环境变量:
- GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE
- GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT
- GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE

未设置时回退到平台级 GATEWAY_ANTIGRAVITY_MAX_RETRIES
2026-01-21 20:48:36 +08:00
song
52c745bc62 feat: 为 Antigravity 平台启用拦截预热请求功能 2026-01-21 20:40:45 +08:00
0xff26b9a8
498c6cfae9 fix: 修复 schema 清理逻辑 2026-01-21 12:08:16 +08:00
0xff26b9a8
71f8b9e473 refactor(antigravity): 提取并同步 Schema 清理逻辑至 schema_cleaner.go
主要变更:
1. 重构代码结构:
   - 将 CleanJSONSchema 及其相关辅助函数从 request_transformer.go 提取到独立的 schema_cleaner.go 文件中,实现逻辑解耦。

2. 逻辑优化与修正:
   - 参考 Antigravity-Manager (json_schema.rs) 的实现逻辑,修正了 Schema 清洗策略。
2026-01-21 12:08:16 +08:00
song
3a31fa4768 fix: 429 限流时更新账号 last_used_at
在设置限流标记时同时更新 last_used_at,使得刚触发 429 的账号
在后续调度中优先级降低,让其他账号有更多被选中的机会。
2026-01-21 11:50:38 +08:00
cyhhao
65e69738cc Merge branch 'main' of github.com:Wei-Shaw/sub2api 2026-01-20 22:46:23 +08:00
song
549c134bb8 chore: gofmt antigravity gateway service 2026-01-20 19:16:43 +08:00
song
d206721fc1 feat: make antigravity max retries configurable 2026-01-20 19:12:19 +08:00
song
64795a03e3 新增账号凭证邮箱查询接口 2026-01-20 14:17:10 +08:00
cyhhao
c8e2f614fa Merge branch 'main' of github.com:Wei-Shaw/sub2api 2026-01-20 13:53:32 +08:00
song
86d63f919d feat(antigravity): 支持秒级 fallback 冷却时间 2026-01-20 11:38:40 +08:00
song
c43aa22cdb feat(antigravity): 支持按映射模型计费 2026-01-20 11:02:08 +08:00
song
d1a6303e49 fix(antigravity): 修复 Claude 非流式响应丢失 2026-01-20 00:52:27 +08:00
Call White
c0347cde85 Merge pull request #2 from cyhhao/fix/claude-oauth-compat
Fix/claude oauth compat
2026-01-19 16:56:49 +08:00
cyhhao
2f2e76f9c6 fix(gateway): gate streaming tool rewrites behind mimic 2026-01-19 16:20:24 +08:00
cyhhao
dd7f21244b merge: resolve conflicts with main 2026-01-19 15:35:31 +08:00
cyhhao
49be9d08f3 fix(网关): OAuth 请求统一 user_id 与指纹 2026-01-19 15:02:00 +08:00
cyhhao
bba5b3c037 fix(网关): OAuth 请求统一 user_id 与指纹 2026-01-19 15:01:32 +08:00
cyhhao
26298c4a5f fix(openai): emit OpenAI-compatible SSE error events 2026-01-19 13:53:39 +08:00
cyhhao
eb7d830296 fix(网关): 修复流式 tool 输入参数转换 2026-01-19 03:57:33 +08:00
cyhhao
02db4c7671 fix(网关): 修复流式 tool 输入参数转换 2026-01-19 03:53:08 +08:00
cyhhao
a05b8b56e3 fix(网关): SSE 缓冲 input_json_delta 反向转换 2026-01-19 03:46:31 +08:00
cyhhao
eca3898410 fix(网关): SSE 缓冲 input_json_delta 反向转换 2026-01-19 03:46:09 +08:00
cyhhao
6901b64fce merge: sync upstream changes 2026-01-17 18:30:16 +08:00
cyhhao
32c47b1509 fix(gateway): satisfy golangci-lint checks 2026-01-17 18:16:34 +08:00
Call White
39e430018b Merge pull request #1 from cyhhao/fix/responses-stream-cancel
fix(gateway): avoid invalid SSE error on canceled stream
2026-01-17 17:15:35 +08:00
nick8802754751
6549a40cf4 refactor: 使用 ConfirmDialog 组件替代原生 confirm() 对话框
- EditAccountModal 和 CreateAccountModal 使用 ConfirmDialog 组件
- 保存 pending payload 供确认后重试
- 添加 mixedChannelWarningTitle 翻译

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2026-01-17 16:16:47 +08:00
nick8802754751
4e75d8fda9 fix: 添加混合渠道警告确认框和过滤 prompt_cache_retention 参数
- 前端: EditAccountModal 和 CreateAccountModal 添加 409 mixed_channel_warning 处理
- 前端: 弹出确认框让用户确认混合渠道风险
- 后端: 过滤 OpenAI 请求中的 prompt_cache_retention 参数(上游不支持)
- 添加中英文翻译

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2026-01-17 16:06:44 +08:00
cyhhao
8917a3ea8f fix(网关): 修复 golangci-lint 2026-01-17 00:27:36 +08:00
cyhhao
0c011b889b fix(网关): Claude Code OAuth 补齐 oauth beta 2026-01-16 23:38:09 +08:00
cyhhao
0962ba43c0 fix(网关): 补齐非 Claude Code OAuth 兼容 2026-01-16 23:38:09 +08:00
cyhhao
b8c48fb477 fix(网关): 区分 Claude Code OAuth 适配 2026-01-16 23:38:09 +08:00
cyhhao
2a7d04fec4 fix(网关): 对齐 Claude OAuth 请求适配 2026-01-16 23:36:57 +08:00
cyhhao
bd854e1750 fix(网关): Claude Code OAuth 补齐 oauth beta 2026-01-16 23:15:52 +08:00
cyhhao
65fd0d15ae fix(网关): 补齐非 Claude Code OAuth 兼容 2026-01-16 00:42:31 +08:00
cyhhao
c11f14f3a0 fix(gateway): drain upstream after client disconnect 2026-01-15 21:51:14 +08:00
cyhhao
98b65e67f2 fix(gateway): avoid injecting invalid SSE on client cancel 2026-01-15 21:42:13 +08:00
cyhhao
c579439c1e fix(网关): 区分 Claude Code OAuth 适配 2026-01-15 19:17:07 +08:00
cyhhao
46e5ac9672 fix(网关): 对齐 Claude OAuth 请求适配 2026-01-15 18:54:42 +08:00
243 changed files with 24267 additions and 8642 deletions

View File

@@ -19,7 +19,7 @@ jobs:
cache: true
- name: Verify Go version
run: |
go version | grep -q 'go1.25.5'
go version | grep -q 'go1.25.6'
- name: Unit tests
working-directory: backend
run: make test-unit
@@ -38,7 +38,7 @@ jobs:
cache: true
- name: Verify Go version
run: |
go version | grep -q 'go1.25.5'
go version | grep -q 'go1.25.6'
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:

View File

@@ -115,7 +115,7 @@ jobs:
- name: Verify Go version
run: |
go version | grep -q 'go1.25.5'
go version | grep -q 'go1.25.6'
# Docker setup for GoReleaser
- name: Set up QEMU

View File

@@ -22,7 +22,7 @@ jobs:
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.25.5'
go version | grep -q 'go1.25.6'
- name: Run govulncheck
working-directory: backend
run: |

View File

@@ -7,7 +7,7 @@
# =============================================================================
ARG NODE_IMAGE=node:24-alpine
ARG GOLANG_IMAGE=golang:1.25.5-alpine
ARG GOLANG_IMAGE=golang:1.25.6-alpine
ARG ALPINE_IMAGE=alpine:3.20
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn

128
README.md
View File

@@ -128,7 +128,7 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
---
### Method 2: Docker Compose
### Method 2: Docker Compose (Recommended)
Deploy with Docker Compose, including PostgreSQL and Redis containers.
@@ -137,87 +137,157 @@ Deploy with Docker Compose, including PostgreSQL and Redis containers.
- Docker 20.10+
- Docker Compose v2+
#### Installation Steps
#### Quick Start (One-Click Deployment)
Use the automated deployment script for easy setup:
```bash
# Create deployment directory
mkdir -p sub2api-deploy && cd sub2api-deploy
# Download and run deployment preparation script
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
# Start services
docker-compose -f docker-compose.local.yml up -d
# View logs
docker-compose -f docker-compose.local.yml logs -f sub2api
```
**What the script does:**
- Downloads `docker-compose.local.yml` and `.env.example`
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
- Creates `.env` file with auto-generated secrets
- Creates data directories (uses local directories for easy backup/migration)
- Displays generated credentials for your reference
#### Manual Deployment
If you prefer manual setup:
```bash
# 1. Clone the repository
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
cd sub2api/deploy
# 2. Enter the deploy directory
cd deploy
# 3. Copy environment configuration
# 2. Copy environment configuration
cp .env.example .env
# 4. Edit configuration (set your passwords)
# 3. Edit configuration (generate secure passwords)
nano .env
```
**Required configuration in `.env`:**
```bash
# PostgreSQL password (REQUIRED - change this!)
# PostgreSQL password (REQUIRED)
POSTGRES_PASSWORD=your_secure_password_here
# JWT Secret (RECOMMENDED - keeps users logged in after restart)
JWT_SECRET=your_jwt_secret_here
# TOTP Encryption Key (RECOMMENDED - preserves 2FA after restart)
TOTP_ENCRYPTION_KEY=your_totp_key_here
# Optional: Admin account
ADMIN_EMAIL=admin@example.com
ADMIN_PASSWORD=your_admin_password
# Optional: Custom port
SERVER_PORT=8080
```
# Optional: Security configuration
# Enable URL allowlist validation (false to skip allowlist checks, only basic format validation)
SECURITY_URL_ALLOWLIST_ENABLED=false
**Generate secure secrets:**
```bash
# Generate JWT_SECRET
openssl rand -hex 32
# Allow insecure HTTP URLs when allowlist is disabled (default: false, requires https)
# ⚠️ WARNING: Enabling this allows HTTP (plaintext) URLs which can expose API keys
# Only recommended for:
# - Development/testing environments
# - Internal networks with trusted endpoints
# - When using local test servers (http://localhost)
# PRODUCTION: Keep this false or use HTTPS URLs only
SECURITY_URL_ALLOWLIST_ALLOW_INSECURE_HTTP=false
# Generate TOTP_ENCRYPTION_KEY
openssl rand -hex 32
# Allow private IP addresses for upstream/pricing/CRS (for internal deployments)
SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=false
# Generate POSTGRES_PASSWORD
openssl rand -hex 32
```
```bash
# 4. Create data directories (for local version)
mkdir -p data postgres_data redis_data
# 5. Start all services
# Option A: Local directory version (recommended - easy migration)
docker-compose -f docker-compose.local.yml up -d
# Option B: Named volumes version (simple setup)
docker-compose up -d
# 6. Check status
docker-compose ps
docker-compose -f docker-compose.local.yml ps
# 7. View logs
docker-compose logs -f sub2api
docker-compose -f docker-compose.local.yml logs -f sub2api
```
#### Deployment Versions
| Version | Data Storage | Migration | Best For |
|---------|-------------|-----------|----------|
| **docker-compose.local.yml** | Local directories | ✅ Easy (tar entire directory) | Production, frequent backups |
| **docker-compose.yml** | Named volumes | ⚠️ Requires docker commands | Simple setup |
**Recommendation:** Use `docker-compose.local.yml` (deployed by script) for easier data management.
#### Access
Open `http://YOUR_SERVER_IP:8080` in your browser.
If admin password was auto-generated, find it in logs:
```bash
docker-compose -f docker-compose.local.yml logs sub2api | grep "admin password"
```
#### Upgrade
```bash
# Pull latest image and recreate container
docker-compose pull
docker-compose up -d
docker-compose -f docker-compose.local.yml pull
docker-compose -f docker-compose.local.yml up -d
```
#### Easy Migration (Local Directory Version)
When using `docker-compose.local.yml`, migrate to a new server easily:
```bash
# On source server
docker-compose -f docker-compose.local.yml down
cd ..
tar czf sub2api-complete.tar.gz sub2api-deploy/
# Transfer to new server
scp sub2api-complete.tar.gz user@new-server:/path/
# On new server
tar xzf sub2api-complete.tar.gz
cd sub2api-deploy/
docker-compose -f docker-compose.local.yml up -d
```
#### Useful Commands
```bash
# Stop all services
docker-compose down
docker-compose -f docker-compose.local.yml down
# Restart
docker-compose restart
docker-compose -f docker-compose.local.yml restart
# View all logs
docker-compose logs -f
docker-compose -f docker-compose.local.yml logs -f
# Remove all data (caution!)
docker-compose -f docker-compose.local.yml down
rm -rf data/ postgres_data/ redis_data/
```
---

View File

@@ -135,7 +135,7 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
---
### 方式二Docker Compose
### 方式二Docker Compose(推荐)
使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
@@ -144,87 +144,157 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
- Docker 20.10+
- Docker Compose v2+
#### 安装步骤
#### 快速开始(一键部署)
使用自动化部署脚本快速搭建:
```bash
# 创建部署目录
mkdir -p sub2api-deploy && cd sub2api-deploy
# 下载并运行部署准备脚本
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
# 启动服务
docker-compose -f docker-compose.local.yml up -d
# 查看日志
docker-compose -f docker-compose.local.yml logs -f sub2api
```
**脚本功能:**
- 下载 `docker-compose.local.yml``.env.example`
- 自动生成安全凭证JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD
- 创建 `.env` 文件并填充自动生成的密钥
- 创建数据目录(使用本地目录,便于备份和迁移)
- 显示生成的凭证供你记录
#### 手动部署
如果你希望手动配置:
```bash
# 1. 克隆仓库
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
cd sub2api/deploy
# 2. 进入 deploy 目录
cd deploy
# 3. 复制环境配置文件
# 2. 复制环境配置文件
cp .env.example .env
# 4. 编辑配置(设置密码
# 3. 编辑配置(生成安全密码)
nano .env
```
**`.env` 必须配置项:**
```bash
# PostgreSQL 密码(必须修改!
# PostgreSQL 密码(必
POSTGRES_PASSWORD=your_secure_password_here
# JWT 密钥(推荐 - 重启后保持用户登录状态)
JWT_SECRET=your_jwt_secret_here
# TOTP 加密密钥(推荐 - 重启后保留双因素认证)
TOTP_ENCRYPTION_KEY=your_totp_key_here
# 可选:管理员账号
ADMIN_EMAIL=admin@example.com
ADMIN_PASSWORD=your_admin_password
# 可选:自定义端口
SERVER_PORT=8080
```
# 可选:安全配置
# 启用 URL 白名单验证false 则跳过白名单检查,仅做基本格式校验)
SECURITY_URL_ALLOWLIST_ENABLED=false
**生成安全密钥:**
```bash
# 生成 JWT_SECRET
openssl rand -hex 32
# 关闭白名单时,是否允许 http:// URL默认 false只允许 https://
# ⚠️ 警告:允许 HTTP 会暴露 API 密钥(明文传输)
# 仅建议在以下场景使用:
# - 开发/测试环境
# - 内部可信网络
# - 本地测试服务器http://localhost
# 生产环境:保持 false 或仅使用 HTTPS URL
SECURITY_URL_ALLOWLIST_ALLOW_INSECURE_HTTP=false
# 生成 TOTP_ENCRYPTION_KEY
openssl rand -hex 32
# 是否允许私有 IP 地址用于上游/定价/CRS内网部署时使用
SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=false
# 生成 POSTGRES_PASSWORD
openssl rand -hex 32
```
```bash
# 4. 创建数据目录(本地版)
mkdir -p data postgres_data redis_data
# 5. 启动所有服务
# 选项 A本地目录版推荐 - 易于迁移)
docker-compose -f docker-compose.local.yml up -d
# 选项 B命名卷版简单设置
docker-compose up -d
# 6. 查看状态
docker-compose ps
docker-compose -f docker-compose.local.yml ps
# 7. 查看日志
docker-compose logs -f sub2api
docker-compose -f docker-compose.local.yml logs -f sub2api
```
#### 部署版本对比
| 版本 | 数据存储 | 迁移便利性 | 适用场景 |
|------|---------|-----------|---------|
| **docker-compose.local.yml** | 本地目录 | ✅ 简单(打包整个目录) | 生产环境、频繁备份 |
| **docker-compose.yml** | 命名卷 | ⚠️ 需要 docker 命令 | 简单设置 |
**推荐:** 使用 `docker-compose.local.yml`(脚本部署)以便更轻松地管理数据。
#### 访问
在浏览器中打开 `http://你的服务器IP:8080`
如果管理员密码是自动生成的,在日志中查找:
```bash
docker-compose -f docker-compose.local.yml logs sub2api | grep "admin password"
```
#### 升级
```bash
# 拉取最新镜像并重建容器
docker-compose pull
docker-compose up -d
docker-compose -f docker-compose.local.yml pull
docker-compose -f docker-compose.local.yml up -d
```
#### 轻松迁移(本地目录版)
使用 `docker-compose.local.yml` 时,可以轻松迁移到新服务器:
```bash
# 源服务器
docker-compose -f docker-compose.local.yml down
cd ..
tar czf sub2api-complete.tar.gz sub2api-deploy/
# 传输到新服务器
scp sub2api-complete.tar.gz user@new-server:/path/
# 新服务器
tar xzf sub2api-complete.tar.gz
cd sub2api-deploy/
docker-compose -f docker-compose.local.yml up -d
```
#### 常用命令
```bash
# 停止所有服务
docker-compose down
docker-compose -f docker-compose.local.yml down
# 重启
docker-compose restart
docker-compose -f docker-compose.local.yml restart
# 查看所有日志
docker-compose logs -f
docker-compose -f docker-compose.local.yml logs -f
# 删除所有数据(谨慎!)
docker-compose -f docker-compose.local.yml down
rm -rf data/ postgres_data/ redis_data/
```
---

View File

@@ -1,4 +1,4 @@
FROM golang:1.25.5-alpine
FROM golang:1.25.6-alpine
WORKDIR /app
@@ -15,7 +15,7 @@ RUN go mod download
COPY . .
# 构建应用
RUN go build -o main cmd/server/main.go
RUN go build -o main ./cmd/server/
# 暴露端口
EXPOSE 8080

View File

@@ -33,7 +33,7 @@ func main() {
}()
userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(userRepo, cfg, nil, nil, nil, nil, nil)
authService := service.NewAuthService(userRepo, nil, cfg, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

View File

@@ -1 +1 @@
0.1.46
0.1.61

View File

@@ -43,6 +43,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, err
}
userRepository := repository.NewUserRepository(client, db)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
settingRepository := repository.NewSettingRepository(client)
settingService := service.NewSettingService(settingRepository, configConfig)
redisClient := repository.ProvideRedis(configConfig)
@@ -61,26 +62,29 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
authService := service.NewAuthService(userRepository, redeemCodeRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil {
return nil, err
}
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
announcementRepository := repository.NewAnnouncementRepository(client)
announcementReadRepository := repository.NewAnnouncementReadRepository(client)
announcementService := service.NewAnnouncementService(announcementRepository, announcementReadRepository, userRepository, userSubscriptionRepository)
announcementHandler := handler.NewAnnouncementHandler(announcementService)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
@@ -128,6 +132,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -167,12 +172,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
totpHandler := handler.NewTotpHandler(totpService)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -183,7 +188,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)

249
backend/ent/announcement.go Normal file
View File

@@ -0,0 +1,249 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
// Announcement is the model entity for the Announcement schema.
type Announcement struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// 公告标题
Title string `json:"title,omitempty"`
// 公告内容(支持 Markdown
Content string `json:"content,omitempty"`
// 状态: draft, active, archived
Status string `json:"status,omitempty"`
// 展示条件JSON 规则)
Targeting domain.AnnouncementTargeting `json:"targeting,omitempty"`
// 开始展示时间(为空表示立即生效)
StartsAt *time.Time `json:"starts_at,omitempty"`
// 结束展示时间(为空表示永久生效)
EndsAt *time.Time `json:"ends_at,omitempty"`
// 创建人用户ID管理员
CreatedBy *int64 `json:"created_by,omitempty"`
// 更新人用户ID管理员
UpdatedBy *int64 `json:"updated_by,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the AnnouncementQuery when eager-loading is set.
Edges AnnouncementEdges `json:"edges"`
selectValues sql.SelectValues
}
// AnnouncementEdges holds the relations/edges for other nodes in the graph.
type AnnouncementEdges struct {
// Reads holds the value of the reads edge.
Reads []*AnnouncementRead `json:"reads,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [1]bool
}
// ReadsOrErr returns the Reads value or an error if the edge
// was not loaded in eager-loading.
func (e AnnouncementEdges) ReadsOrErr() ([]*AnnouncementRead, error) {
if e.loadedTypes[0] {
return e.Reads, nil
}
return nil, &NotLoadedError{edge: "reads"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*Announcement) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case announcement.FieldTargeting:
values[i] = new([]byte)
case announcement.FieldID, announcement.FieldCreatedBy, announcement.FieldUpdatedBy:
values[i] = new(sql.NullInt64)
case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus:
values[i] = new(sql.NullString)
case announcement.FieldStartsAt, announcement.FieldEndsAt, announcement.FieldCreatedAt, announcement.FieldUpdatedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the Announcement fields.
func (_m *Announcement) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case announcement.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case announcement.FieldTitle:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field title", values[i])
} else if value.Valid {
_m.Title = value.String
}
case announcement.FieldContent:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field content", values[i])
} else if value.Valid {
_m.Content = value.String
}
case announcement.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
} else if value.Valid {
_m.Status = value.String
}
case announcement.FieldTargeting:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field targeting", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.Targeting); err != nil {
return fmt.Errorf("unmarshal field targeting: %w", err)
}
}
case announcement.FieldStartsAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field starts_at", values[i])
} else if value.Valid {
_m.StartsAt = new(time.Time)
*_m.StartsAt = value.Time
}
case announcement.FieldEndsAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field ends_at", values[i])
} else if value.Valid {
_m.EndsAt = new(time.Time)
*_m.EndsAt = value.Time
}
case announcement.FieldCreatedBy:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field created_by", values[i])
} else if value.Valid {
_m.CreatedBy = new(int64)
*_m.CreatedBy = value.Int64
}
case announcement.FieldUpdatedBy:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field updated_by", values[i])
} else if value.Valid {
_m.UpdatedBy = new(int64)
*_m.UpdatedBy = value.Int64
}
case announcement.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case announcement.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the Announcement.
// This includes values selected through modifiers, order, etc.
func (_m *Announcement) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryReads queries the "reads" edge of the Announcement entity.
func (_m *Announcement) QueryReads() *AnnouncementReadQuery {
return NewAnnouncementClient(_m.config).QueryReads(_m)
}
// Update returns a builder for updating this Announcement.
// Note that you need to call Announcement.Unwrap() before calling this method if this Announcement
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *Announcement) Update() *AnnouncementUpdateOne {
return NewAnnouncementClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the Announcement entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *Announcement) Unwrap() *Announcement {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: Announcement is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *Announcement) String() string {
var builder strings.Builder
builder.WriteString("Announcement(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("title=")
builder.WriteString(_m.Title)
builder.WriteString(", ")
builder.WriteString("content=")
builder.WriteString(_m.Content)
builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
builder.WriteString("targeting=")
builder.WriteString(fmt.Sprintf("%v", _m.Targeting))
builder.WriteString(", ")
if v := _m.StartsAt; v != nil {
builder.WriteString("starts_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.EndsAt; v != nil {
builder.WriteString("ends_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.CreatedBy; v != nil {
builder.WriteString("created_by=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.UpdatedBy; v != nil {
builder.WriteString("updated_by=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteByte(')')
return builder.String()
}
// Announcements is a parsable slice of Announcement.
type Announcements []*Announcement

View File

@@ -0,0 +1,164 @@
// Code generated by ent, DO NOT EDIT.
package announcement
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the announcement type in the database.
Label = "announcement"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldTitle holds the string denoting the title field in the database.
FieldTitle = "title"
// FieldContent holds the string denoting the content field in the database.
FieldContent = "content"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldTargeting holds the string denoting the targeting field in the database.
FieldTargeting = "targeting"
// FieldStartsAt holds the string denoting the starts_at field in the database.
FieldStartsAt = "starts_at"
// FieldEndsAt holds the string denoting the ends_at field in the database.
FieldEndsAt = "ends_at"
// FieldCreatedBy holds the string denoting the created_by field in the database.
FieldCreatedBy = "created_by"
// FieldUpdatedBy holds the string denoting the updated_by field in the database.
FieldUpdatedBy = "updated_by"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// EdgeReads holds the string denoting the reads edge name in mutations.
EdgeReads = "reads"
// Table holds the table name of the announcement in the database.
Table = "announcements"
// ReadsTable is the table that holds the reads relation/edge.
ReadsTable = "announcement_reads"
// ReadsInverseTable is the table name for the AnnouncementRead entity.
// It exists in this package in order to avoid circular dependency with the "announcementread" package.
ReadsInverseTable = "announcement_reads"
// ReadsColumn is the table column denoting the reads relation/edge.
ReadsColumn = "announcement_id"
)
// Columns holds all SQL columns for announcement fields.
var Columns = []string{
FieldID,
FieldTitle,
FieldContent,
FieldStatus,
FieldTargeting,
FieldStartsAt,
FieldEndsAt,
FieldCreatedBy,
FieldUpdatedBy,
FieldCreatedAt,
FieldUpdatedAt,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// TitleValidator is a validator for the "title" field. It is called by the builders before save.
TitleValidator func(string) error
// ContentValidator is a validator for the "content" field. It is called by the builders before save.
ContentValidator func(string) error
// DefaultStatus holds the default value on creation for the "status" field.
DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
StatusValidator func(string) error
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
)
// OrderOption defines the ordering options for the Announcement queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByTitle orders the results by the title field.
func ByTitle(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTitle, opts...).ToFunc()
}
// ByContent orders the results by the content field.
func ByContent(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldContent, opts...).ToFunc()
}
// ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
}
// ByStartsAt orders the results by the starts_at field.
func ByStartsAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStartsAt, opts...).ToFunc()
}
// ByEndsAt orders the results by the ends_at field.
func ByEndsAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEndsAt, opts...).ToFunc()
}
// ByCreatedBy orders the results by the created_by field.
func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedBy, opts...).ToFunc()
}
// ByUpdatedBy orders the results by the updated_by field.
func ByUpdatedBy(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedBy, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByReadsCount orders the results by reads count.
func ByReadsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newReadsStep(), opts...)
}
}
// ByReads orders the results by reads terms.
func ByReads(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newReadsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newReadsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(ReadsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, ReadsTable, ReadsColumn),
)
}

View File

@@ -0,0 +1,624 @@
// Code generated by ent, DO NOT EDIT.
package announcement
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldID, id))
}
// Title applies equality check predicate on the "title" field. It's identical to TitleEQ.
func Title(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldTitle, v))
}
// Content applies equality check predicate on the "content" field. It's identical to ContentEQ.
func Content(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldContent, v))
}
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldStatus, v))
}
// StartsAt applies equality check predicate on the "starts_at" field. It's identical to StartsAtEQ.
func StartsAt(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v))
}
// EndsAt applies equality check predicate on the "ends_at" field. It's identical to EndsAtEQ.
func EndsAt(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldEndsAt, v))
}
// CreatedBy applies equality check predicate on the "created_by" field. It's identical to CreatedByEQ.
func CreatedBy(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldCreatedBy, v))
}
// UpdatedBy applies equality check predicate on the "updated_by" field. It's identical to UpdatedByEQ.
func UpdatedBy(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldUpdatedBy, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldUpdatedAt, v))
}
// TitleEQ applies the EQ predicate on the "title" field.
func TitleEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldTitle, v))
}
// TitleNEQ applies the NEQ predicate on the "title" field.
func TitleNEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldTitle, v))
}
// TitleIn applies the In predicate on the "title" field.
func TitleIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldTitle, vs...))
}
// TitleNotIn applies the NotIn predicate on the "title" field.
func TitleNotIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldTitle, vs...))
}
// TitleGT applies the GT predicate on the "title" field.
func TitleGT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldTitle, v))
}
// TitleGTE applies the GTE predicate on the "title" field.
func TitleGTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldTitle, v))
}
// TitleLT applies the LT predicate on the "title" field.
func TitleLT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldTitle, v))
}
// TitleLTE applies the LTE predicate on the "title" field.
func TitleLTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldTitle, v))
}
// TitleContains applies the Contains predicate on the "title" field.
func TitleContains(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContains(FieldTitle, v))
}
// TitleHasPrefix applies the HasPrefix predicate on the "title" field.
func TitleHasPrefix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasPrefix(FieldTitle, v))
}
// TitleHasSuffix applies the HasSuffix predicate on the "title" field.
func TitleHasSuffix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasSuffix(FieldTitle, v))
}
// TitleEqualFold applies the EqualFold predicate on the "title" field.
func TitleEqualFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEqualFold(FieldTitle, v))
}
// TitleContainsFold applies the ContainsFold predicate on the "title" field.
func TitleContainsFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContainsFold(FieldTitle, v))
}
// ContentEQ applies the EQ predicate on the "content" field.
func ContentEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldContent, v))
}
// ContentNEQ applies the NEQ predicate on the "content" field.
func ContentNEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldContent, v))
}
// ContentIn applies the In predicate on the "content" field.
func ContentIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldContent, vs...))
}
// ContentNotIn applies the NotIn predicate on the "content" field.
func ContentNotIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldContent, vs...))
}
// ContentGT applies the GT predicate on the "content" field.
func ContentGT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldContent, v))
}
// ContentGTE applies the GTE predicate on the "content" field.
func ContentGTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldContent, v))
}
// ContentLT applies the LT predicate on the "content" field.
func ContentLT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldContent, v))
}
// ContentLTE applies the LTE predicate on the "content" field.
func ContentLTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldContent, v))
}
// ContentContains applies the Contains predicate on the "content" field.
func ContentContains(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContains(FieldContent, v))
}
// ContentHasPrefix applies the HasPrefix predicate on the "content" field.
func ContentHasPrefix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasPrefix(FieldContent, v))
}
// ContentHasSuffix applies the HasSuffix predicate on the "content" field.
func ContentHasSuffix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasSuffix(FieldContent, v))
}
// ContentEqualFold applies the EqualFold predicate on the "content" field.
func ContentEqualFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEqualFold(FieldContent, v))
}
// ContentContainsFold applies the ContainsFold predicate on the "content" field.
func ContentContainsFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContainsFold(FieldContent, v))
}
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldStatus, v))
}
// StatusNEQ applies the NEQ predicate on the "status" field.
func StatusNEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldStatus, v))
}
// StatusIn applies the In predicate on the "status" field.
func StatusIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldStatus, vs...))
}
// StatusNotIn applies the NotIn predicate on the "status" field.
func StatusNotIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldStatus, vs...))
}
// StatusGT applies the GT predicate on the "status" field.
func StatusGT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldStatus, v))
}
// StatusGTE applies the GTE predicate on the "status" field.
func StatusGTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldStatus, v))
}
// StatusLT applies the LT predicate on the "status" field.
func StatusLT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldStatus, v))
}
// StatusLTE applies the LTE predicate on the "status" field.
func StatusLTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldStatus, v))
}
// StatusContains applies the Contains predicate on the "status" field.
func StatusContains(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContains(FieldStatus, v))
}
// StatusHasPrefix applies the HasPrefix predicate on the "status" field.
func StatusHasPrefix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasPrefix(FieldStatus, v))
}
// StatusHasSuffix applies the HasSuffix predicate on the "status" field.
func StatusHasSuffix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasSuffix(FieldStatus, v))
}
// StatusEqualFold applies the EqualFold predicate on the "status" field.
func StatusEqualFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEqualFold(FieldStatus, v))
}
// StatusContainsFold applies the ContainsFold predicate on the "status" field.
func StatusContainsFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContainsFold(FieldStatus, v))
}
// TargetingIsNil applies the IsNil predicate on the "targeting" field.
func TargetingIsNil() predicate.Announcement {
return predicate.Announcement(sql.FieldIsNull(FieldTargeting))
}
// TargetingNotNil applies the NotNil predicate on the "targeting" field.
func TargetingNotNil() predicate.Announcement {
return predicate.Announcement(sql.FieldNotNull(FieldTargeting))
}
// StartsAtEQ applies the EQ predicate on the "starts_at" field.
func StartsAtEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v))
}
// StartsAtNEQ applies the NEQ predicate on the "starts_at" field.
func StartsAtNEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldStartsAt, v))
}
// StartsAtIn applies the In predicate on the "starts_at" field.
func StartsAtIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldStartsAt, vs...))
}
// StartsAtNotIn applies the NotIn predicate on the "starts_at" field.
func StartsAtNotIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldStartsAt, vs...))
}
// StartsAtGT applies the GT predicate on the "starts_at" field.
func StartsAtGT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldStartsAt, v))
}
// StartsAtGTE applies the GTE predicate on the "starts_at" field.
func StartsAtGTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldStartsAt, v))
}
// StartsAtLT applies the LT predicate on the "starts_at" field.
func StartsAtLT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldStartsAt, v))
}
// StartsAtLTE applies the LTE predicate on the "starts_at" field.
func StartsAtLTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldStartsAt, v))
}
// StartsAtIsNil applies the IsNil predicate on the "starts_at" field.
func StartsAtIsNil() predicate.Announcement {
return predicate.Announcement(sql.FieldIsNull(FieldStartsAt))
}
// StartsAtNotNil applies the NotNil predicate on the "starts_at" field.
func StartsAtNotNil() predicate.Announcement {
return predicate.Announcement(sql.FieldNotNull(FieldStartsAt))
}
// EndsAtEQ applies the EQ predicate on the "ends_at" field.
func EndsAtEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldEndsAt, v))
}
// EndsAtNEQ applies the NEQ predicate on the "ends_at" field.
func EndsAtNEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldEndsAt, v))
}
// EndsAtIn applies the In predicate on the "ends_at" field.
func EndsAtIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldEndsAt, vs...))
}
// EndsAtNotIn applies the NotIn predicate on the "ends_at" field.
func EndsAtNotIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldEndsAt, vs...))
}
// EndsAtGT applies the GT predicate on the "ends_at" field.
func EndsAtGT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldEndsAt, v))
}
// EndsAtGTE applies the GTE predicate on the "ends_at" field.
func EndsAtGTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldEndsAt, v))
}
// EndsAtLT applies the LT predicate on the "ends_at" field.
func EndsAtLT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldEndsAt, v))
}
// EndsAtLTE applies the LTE predicate on the "ends_at" field.
func EndsAtLTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldEndsAt, v))
}
// EndsAtIsNil applies the IsNil predicate on the "ends_at" field.
func EndsAtIsNil() predicate.Announcement {
return predicate.Announcement(sql.FieldIsNull(FieldEndsAt))
}
// EndsAtNotNil applies the NotNil predicate on the "ends_at" field.
func EndsAtNotNil() predicate.Announcement {
return predicate.Announcement(sql.FieldNotNull(FieldEndsAt))
}
// CreatedByEQ applies the EQ predicate on the "created_by" field.
func CreatedByEQ(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldCreatedBy, v))
}
// CreatedByNEQ applies the NEQ predicate on the "created_by" field.
func CreatedByNEQ(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldCreatedBy, v))
}
// CreatedByIn applies the In predicate on the "created_by" field.
func CreatedByIn(vs ...int64) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldCreatedBy, vs...))
}
// CreatedByNotIn applies the NotIn predicate on the "created_by" field.
func CreatedByNotIn(vs ...int64) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldCreatedBy, vs...))
}
// CreatedByGT applies the GT predicate on the "created_by" field.
func CreatedByGT(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldCreatedBy, v))
}
// CreatedByGTE applies the GTE predicate on the "created_by" field.
func CreatedByGTE(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldCreatedBy, v))
}
// CreatedByLT applies the LT predicate on the "created_by" field.
func CreatedByLT(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldCreatedBy, v))
}
// CreatedByLTE applies the LTE predicate on the "created_by" field.
func CreatedByLTE(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldCreatedBy, v))
}
// CreatedByIsNil applies the IsNil predicate on the "created_by" field.
func CreatedByIsNil() predicate.Announcement {
return predicate.Announcement(sql.FieldIsNull(FieldCreatedBy))
}
// CreatedByNotNil applies the NotNil predicate on the "created_by" field.
func CreatedByNotNil() predicate.Announcement {
return predicate.Announcement(sql.FieldNotNull(FieldCreatedBy))
}
// UpdatedByEQ applies the EQ predicate on the "updated_by" field.
func UpdatedByEQ(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldUpdatedBy, v))
}
// UpdatedByNEQ applies the NEQ predicate on the "updated_by" field.
func UpdatedByNEQ(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldUpdatedBy, v))
}
// UpdatedByIn applies the In predicate on the "updated_by" field.
func UpdatedByIn(vs ...int64) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldUpdatedBy, vs...))
}
// UpdatedByNotIn applies the NotIn predicate on the "updated_by" field.
func UpdatedByNotIn(vs ...int64) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldUpdatedBy, vs...))
}
// UpdatedByGT applies the GT predicate on the "updated_by" field.
func UpdatedByGT(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldUpdatedBy, v))
}
// UpdatedByGTE applies the GTE predicate on the "updated_by" field.
func UpdatedByGTE(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldUpdatedBy, v))
}
// UpdatedByLT applies the LT predicate on the "updated_by" field.
func UpdatedByLT(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldUpdatedBy, v))
}
// UpdatedByLTE applies the LTE predicate on the "updated_by" field.
func UpdatedByLTE(v int64) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldUpdatedBy, v))
}
// UpdatedByIsNil applies the IsNil predicate on the "updated_by" field.
func UpdatedByIsNil() predicate.Announcement {
return predicate.Announcement(sql.FieldIsNull(FieldUpdatedBy))
}
// UpdatedByNotNil applies the NotNil predicate on the "updated_by" field.
func UpdatedByNotNil() predicate.Announcement {
return predicate.Announcement(sql.FieldNotNull(FieldUpdatedBy))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldUpdatedAt, v))
}
// HasReads applies the HasEdge predicate on the "reads" edge.
func HasReads() predicate.Announcement {
return predicate.Announcement(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, ReadsTable, ReadsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasReadsWith applies the HasEdge predicate on the "reads" edge with a given conditions (other predicates).
func HasReadsWith(preds ...predicate.AnnouncementRead) predicate.Announcement {
return predicate.Announcement(func(s *sql.Selector) {
step := newReadsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.Announcement) predicate.Announcement {
return predicate.Announcement(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.Announcement) predicate.Announcement {
return predicate.Announcement(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.Announcement) predicate.Announcement {
return predicate.Announcement(sql.NotPredicates(p))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// AnnouncementDelete is the builder for deleting a Announcement entity.
type AnnouncementDelete struct {
config
hooks []Hook
mutation *AnnouncementMutation
}
// Where appends a list predicates to the AnnouncementDelete builder.
func (_d *AnnouncementDelete) Where(ps ...predicate.Announcement) *AnnouncementDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *AnnouncementDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AnnouncementDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *AnnouncementDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(announcement.Table, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// AnnouncementDeleteOne is the builder for deleting a single Announcement entity.
type AnnouncementDeleteOne struct {
_d *AnnouncementDelete
}
// Where appends a list predicates to the AnnouncementDelete builder.
func (_d *AnnouncementDeleteOne) Where(ps ...predicate.Announcement) *AnnouncementDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *AnnouncementDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{announcement.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AnnouncementDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,643 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"database/sql/driver"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// AnnouncementQuery is the builder for querying Announcement entities.
type AnnouncementQuery struct {
config
ctx *QueryContext
order []announcement.OrderOption
inters []Interceptor
predicates []predicate.Announcement
withReads *AnnouncementReadQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the AnnouncementQuery builder.
func (_q *AnnouncementQuery) Where(ps ...predicate.Announcement) *AnnouncementQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *AnnouncementQuery) Limit(limit int) *AnnouncementQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *AnnouncementQuery) Offset(offset int) *AnnouncementQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *AnnouncementQuery) Unique(unique bool) *AnnouncementQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *AnnouncementQuery) Order(o ...announcement.OrderOption) *AnnouncementQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryReads chains the current query on the "reads" edge.
func (_q *AnnouncementQuery) QueryReads() *AnnouncementReadQuery {
query := (&AnnouncementReadClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(announcement.Table, announcement.FieldID, selector),
sqlgraph.To(announcementread.Table, announcementread.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, announcement.ReadsTable, announcement.ReadsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first Announcement entity from the query.
// Returns a *NotFoundError when no Announcement was found.
func (_q *AnnouncementQuery) First(ctx context.Context) (*Announcement, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{announcement.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *AnnouncementQuery) FirstX(ctx context.Context) *Announcement {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first Announcement ID from the query.
// Returns a *NotFoundError when no Announcement ID was found.
func (_q *AnnouncementQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{announcement.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *AnnouncementQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single Announcement entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one Announcement entity is found.
// Returns a *NotFoundError when no Announcement entities are found.
func (_q *AnnouncementQuery) Only(ctx context.Context) (*Announcement, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{announcement.Label}
default:
return nil, &NotSingularError{announcement.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *AnnouncementQuery) OnlyX(ctx context.Context) *Announcement {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only Announcement ID in the query.
// Returns a *NotSingularError when more than one Announcement ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *AnnouncementQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{announcement.Label}
default:
err = &NotSingularError{announcement.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *AnnouncementQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of Announcements.
func (_q *AnnouncementQuery) All(ctx context.Context) ([]*Announcement, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*Announcement, *AnnouncementQuery]()
return withInterceptors[[]*Announcement](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *AnnouncementQuery) AllX(ctx context.Context) []*Announcement {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of Announcement IDs.
func (_q *AnnouncementQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(announcement.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *AnnouncementQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *AnnouncementQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*AnnouncementQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *AnnouncementQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *AnnouncementQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *AnnouncementQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the AnnouncementQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *AnnouncementQuery) Clone() *AnnouncementQuery {
if _q == nil {
return nil
}
return &AnnouncementQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]announcement.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.Announcement{}, _q.predicates...),
withReads: _q.withReads.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithReads tells the query-builder to eager-load the nodes that are connected to
// the "reads" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AnnouncementQuery) WithReads(opts ...func(*AnnouncementReadQuery)) *AnnouncementQuery {
query := (&AnnouncementReadClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withReads = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// Title string `json:"title,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.Announcement.Query().
// GroupBy(announcement.FieldTitle).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *AnnouncementQuery) GroupBy(field string, fields ...string) *AnnouncementGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &AnnouncementGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = announcement.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// Title string `json:"title,omitempty"`
// }
//
// client.Announcement.Query().
// Select(announcement.FieldTitle).
// Scan(ctx, &v)
func (_q *AnnouncementQuery) Select(fields ...string) *AnnouncementSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &AnnouncementSelect{AnnouncementQuery: _q}
sbuild.label = announcement.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a AnnouncementSelect configured with the given aggregations.
func (_q *AnnouncementQuery) Aggregate(fns ...AggregateFunc) *AnnouncementSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *AnnouncementQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !announcement.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *AnnouncementQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Announcement, error) {
var (
nodes = []*Announcement{}
_spec = _q.querySpec()
loadedTypes = [1]bool{
_q.withReads != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*Announcement).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &Announcement{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withReads; query != nil {
if err := _q.loadReads(ctx, query, nodes,
func(n *Announcement) { n.Edges.Reads = []*AnnouncementRead{} },
func(n *Announcement, e *AnnouncementRead) { n.Edges.Reads = append(n.Edges.Reads, e) }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *AnnouncementQuery) loadReads(ctx context.Context, query *AnnouncementReadQuery, nodes []*Announcement, init func(*Announcement), assign func(*Announcement, *AnnouncementRead)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*Announcement)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(announcementread.FieldAnnouncementID)
}
query.Where(predicate.AnnouncementRead(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(announcement.ReadsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.AnnouncementID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "announcement_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *AnnouncementQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *AnnouncementQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(announcement.Table, announcement.Columns, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, announcement.FieldID)
for i := range fields {
if fields[i] != announcement.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *AnnouncementQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(announcement.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = announcement.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *AnnouncementQuery) ForUpdate(opts ...sql.LockOption) *AnnouncementQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *AnnouncementQuery) ForShare(opts ...sql.LockOption) *AnnouncementQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// AnnouncementGroupBy is the group-by builder for Announcement entities.
type AnnouncementGroupBy struct {
selector
build *AnnouncementQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *AnnouncementGroupBy) Aggregate(fns ...AggregateFunc) *AnnouncementGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *AnnouncementGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AnnouncementQuery, *AnnouncementGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *AnnouncementGroupBy) sqlScan(ctx context.Context, root *AnnouncementQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// AnnouncementSelect is the builder for selecting fields of Announcement entities.
type AnnouncementSelect struct {
*AnnouncementQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *AnnouncementSelect) Aggregate(fns ...AggregateFunc) *AnnouncementSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *AnnouncementSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AnnouncementQuery, *AnnouncementSelect](ctx, _s.AnnouncementQuery, _s, _s.inters, v)
}
func (_s *AnnouncementSelect) sqlScan(ctx context.Context, root *AnnouncementQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,824 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
// AnnouncementUpdate is the builder for updating Announcement entities.
type AnnouncementUpdate struct {
config
hooks []Hook
mutation *AnnouncementMutation
}
// Where appends a list predicates to the AnnouncementUpdate builder.
func (_u *AnnouncementUpdate) Where(ps ...predicate.Announcement) *AnnouncementUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetTitle sets the "title" field.
func (_u *AnnouncementUpdate) SetTitle(v string) *AnnouncementUpdate {
_u.mutation.SetTitle(v)
return _u
}
// SetNillableTitle sets the "title" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableTitle(v *string) *AnnouncementUpdate {
if v != nil {
_u.SetTitle(*v)
}
return _u
}
// SetContent sets the "content" field.
func (_u *AnnouncementUpdate) SetContent(v string) *AnnouncementUpdate {
_u.mutation.SetContent(v)
return _u
}
// SetNillableContent sets the "content" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableContent(v *string) *AnnouncementUpdate {
if v != nil {
_u.SetContent(*v)
}
return _u
}
// SetStatus sets the "status" field.
func (_u *AnnouncementUpdate) SetStatus(v string) *AnnouncementUpdate {
_u.mutation.SetStatus(v)
return _u
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableStatus(v *string) *AnnouncementUpdate {
if v != nil {
_u.SetStatus(*v)
}
return _u
}
// SetTargeting sets the "targeting" field.
func (_u *AnnouncementUpdate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdate {
_u.mutation.SetTargeting(v)
return _u
}
// SetNillableTargeting sets the "targeting" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableTargeting(v *domain.AnnouncementTargeting) *AnnouncementUpdate {
if v != nil {
_u.SetTargeting(*v)
}
return _u
}
// ClearTargeting clears the value of the "targeting" field.
func (_u *AnnouncementUpdate) ClearTargeting() *AnnouncementUpdate {
_u.mutation.ClearTargeting()
return _u
}
// SetStartsAt sets the "starts_at" field.
func (_u *AnnouncementUpdate) SetStartsAt(v time.Time) *AnnouncementUpdate {
_u.mutation.SetStartsAt(v)
return _u
}
// SetNillableStartsAt sets the "starts_at" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableStartsAt(v *time.Time) *AnnouncementUpdate {
if v != nil {
_u.SetStartsAt(*v)
}
return _u
}
// ClearStartsAt clears the value of the "starts_at" field.
func (_u *AnnouncementUpdate) ClearStartsAt() *AnnouncementUpdate {
_u.mutation.ClearStartsAt()
return _u
}
// SetEndsAt sets the "ends_at" field.
func (_u *AnnouncementUpdate) SetEndsAt(v time.Time) *AnnouncementUpdate {
_u.mutation.SetEndsAt(v)
return _u
}
// SetNillableEndsAt sets the "ends_at" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableEndsAt(v *time.Time) *AnnouncementUpdate {
if v != nil {
_u.SetEndsAt(*v)
}
return _u
}
// ClearEndsAt clears the value of the "ends_at" field.
func (_u *AnnouncementUpdate) ClearEndsAt() *AnnouncementUpdate {
_u.mutation.ClearEndsAt()
return _u
}
// SetCreatedBy sets the "created_by" field.
func (_u *AnnouncementUpdate) SetCreatedBy(v int64) *AnnouncementUpdate {
_u.mutation.ResetCreatedBy()
_u.mutation.SetCreatedBy(v)
return _u
}
// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableCreatedBy(v *int64) *AnnouncementUpdate {
if v != nil {
_u.SetCreatedBy(*v)
}
return _u
}
// AddCreatedBy adds value to the "created_by" field.
func (_u *AnnouncementUpdate) AddCreatedBy(v int64) *AnnouncementUpdate {
_u.mutation.AddCreatedBy(v)
return _u
}
// ClearCreatedBy clears the value of the "created_by" field.
func (_u *AnnouncementUpdate) ClearCreatedBy() *AnnouncementUpdate {
_u.mutation.ClearCreatedBy()
return _u
}
// SetUpdatedBy sets the "updated_by" field.
func (_u *AnnouncementUpdate) SetUpdatedBy(v int64) *AnnouncementUpdate {
_u.mutation.ResetUpdatedBy()
_u.mutation.SetUpdatedBy(v)
return _u
}
// SetNillableUpdatedBy sets the "updated_by" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableUpdatedBy(v *int64) *AnnouncementUpdate {
if v != nil {
_u.SetUpdatedBy(*v)
}
return _u
}
// AddUpdatedBy adds value to the "updated_by" field.
func (_u *AnnouncementUpdate) AddUpdatedBy(v int64) *AnnouncementUpdate {
_u.mutation.AddUpdatedBy(v)
return _u
}
// ClearUpdatedBy clears the value of the "updated_by" field.
func (_u *AnnouncementUpdate) ClearUpdatedBy() *AnnouncementUpdate {
_u.mutation.ClearUpdatedBy()
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *AnnouncementUpdate) SetUpdatedAt(v time.Time) *AnnouncementUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// AddReadIDs adds the "reads" edge to the AnnouncementRead entity by IDs.
func (_u *AnnouncementUpdate) AddReadIDs(ids ...int64) *AnnouncementUpdate {
_u.mutation.AddReadIDs(ids...)
return _u
}
// AddReads adds the "reads" edges to the AnnouncementRead entity.
func (_u *AnnouncementUpdate) AddReads(v ...*AnnouncementRead) *AnnouncementUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddReadIDs(ids...)
}
// Mutation returns the AnnouncementMutation object of the builder.
func (_u *AnnouncementUpdate) Mutation() *AnnouncementMutation {
return _u.mutation
}
// ClearReads clears all "reads" edges to the AnnouncementRead entity.
func (_u *AnnouncementUpdate) ClearReads() *AnnouncementUpdate {
_u.mutation.ClearReads()
return _u
}
// RemoveReadIDs removes the "reads" edge to AnnouncementRead entities by IDs.
func (_u *AnnouncementUpdate) RemoveReadIDs(ids ...int64) *AnnouncementUpdate {
_u.mutation.RemoveReadIDs(ids...)
return _u
}
// RemoveReads removes "reads" edges to AnnouncementRead entities.
func (_u *AnnouncementUpdate) RemoveReads(v ...*AnnouncementRead) *AnnouncementUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveReadIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *AnnouncementUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AnnouncementUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *AnnouncementUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AnnouncementUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *AnnouncementUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := announcement.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *AnnouncementUpdate) check() error {
if v, ok := _u.mutation.Title(); ok {
if err := announcement.TitleValidator(v); err != nil {
return &ValidationError{Name: "title", err: fmt.Errorf(`ent: validator failed for field "Announcement.title": %w`, err)}
}
}
if v, ok := _u.mutation.Content(); ok {
if err := announcement.ContentValidator(v); err != nil {
return &ValidationError{Name: "content", err: fmt.Errorf(`ent: validator failed for field "Announcement.content": %w`, err)}
}
}
if v, ok := _u.mutation.Status(); ok {
if err := announcement.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
}
}
return nil
}
func (_u *AnnouncementUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(announcement.Table, announcement.Columns, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.Title(); ok {
_spec.SetField(announcement.FieldTitle, field.TypeString, value)
}
if value, ok := _u.mutation.Content(); ok {
_spec.SetField(announcement.FieldContent, field.TypeString, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.Targeting(); ok {
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
}
if _u.mutation.TargetingCleared() {
_spec.ClearField(announcement.FieldTargeting, field.TypeJSON)
}
if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(announcement.FieldStartsAt, field.TypeTime, value)
}
if _u.mutation.StartsAtCleared() {
_spec.ClearField(announcement.FieldStartsAt, field.TypeTime)
}
if value, ok := _u.mutation.EndsAt(); ok {
_spec.SetField(announcement.FieldEndsAt, field.TypeTime, value)
}
if _u.mutation.EndsAtCleared() {
_spec.ClearField(announcement.FieldEndsAt, field.TypeTime)
}
if value, ok := _u.mutation.CreatedBy(); ok {
_spec.SetField(announcement.FieldCreatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedCreatedBy(); ok {
_spec.AddField(announcement.FieldCreatedBy, field.TypeInt64, value)
}
if _u.mutation.CreatedByCleared() {
_spec.ClearField(announcement.FieldCreatedBy, field.TypeInt64)
}
if value, ok := _u.mutation.UpdatedBy(); ok {
_spec.SetField(announcement.FieldUpdatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedUpdatedBy(); ok {
_spec.AddField(announcement.FieldUpdatedBy, field.TypeInt64, value)
}
if _u.mutation.UpdatedByCleared() {
_spec.ClearField(announcement.FieldUpdatedBy, field.TypeInt64)
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(announcement.FieldUpdatedAt, field.TypeTime, value)
}
if _u.mutation.ReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: announcement.ReadsTable,
Columns: []string{announcement.ReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedReadsIDs(); len(nodes) > 0 && !_u.mutation.ReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: announcement.ReadsTable,
Columns: []string{announcement.ReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.ReadsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: announcement.ReadsTable,
Columns: []string{announcement.ReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{announcement.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// AnnouncementUpdateOne is the builder for updating a single Announcement entity.
type AnnouncementUpdateOne struct {
config
fields []string
hooks []Hook
mutation *AnnouncementMutation
}
// SetTitle sets the "title" field.
func (_u *AnnouncementUpdateOne) SetTitle(v string) *AnnouncementUpdateOne {
_u.mutation.SetTitle(v)
return _u
}
// SetNillableTitle sets the "title" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableTitle(v *string) *AnnouncementUpdateOne {
if v != nil {
_u.SetTitle(*v)
}
return _u
}
// SetContent sets the "content" field.
func (_u *AnnouncementUpdateOne) SetContent(v string) *AnnouncementUpdateOne {
_u.mutation.SetContent(v)
return _u
}
// SetNillableContent sets the "content" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableContent(v *string) *AnnouncementUpdateOne {
if v != nil {
_u.SetContent(*v)
}
return _u
}
// SetStatus sets the "status" field.
func (_u *AnnouncementUpdateOne) SetStatus(v string) *AnnouncementUpdateOne {
_u.mutation.SetStatus(v)
return _u
}
// SetNillableStatus sets the "status" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableStatus(v *string) *AnnouncementUpdateOne {
if v != nil {
_u.SetStatus(*v)
}
return _u
}
// SetTargeting sets the "targeting" field.
func (_u *AnnouncementUpdateOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdateOne {
_u.mutation.SetTargeting(v)
return _u
}
// SetNillableTargeting sets the "targeting" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableTargeting(v *domain.AnnouncementTargeting) *AnnouncementUpdateOne {
if v != nil {
_u.SetTargeting(*v)
}
return _u
}
// ClearTargeting clears the value of the "targeting" field.
func (_u *AnnouncementUpdateOne) ClearTargeting() *AnnouncementUpdateOne {
_u.mutation.ClearTargeting()
return _u
}
// SetStartsAt sets the "starts_at" field.
func (_u *AnnouncementUpdateOne) SetStartsAt(v time.Time) *AnnouncementUpdateOne {
_u.mutation.SetStartsAt(v)
return _u
}
// SetNillableStartsAt sets the "starts_at" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableStartsAt(v *time.Time) *AnnouncementUpdateOne {
if v != nil {
_u.SetStartsAt(*v)
}
return _u
}
// ClearStartsAt clears the value of the "starts_at" field.
func (_u *AnnouncementUpdateOne) ClearStartsAt() *AnnouncementUpdateOne {
_u.mutation.ClearStartsAt()
return _u
}
// SetEndsAt sets the "ends_at" field.
func (_u *AnnouncementUpdateOne) SetEndsAt(v time.Time) *AnnouncementUpdateOne {
_u.mutation.SetEndsAt(v)
return _u
}
// SetNillableEndsAt sets the "ends_at" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableEndsAt(v *time.Time) *AnnouncementUpdateOne {
if v != nil {
_u.SetEndsAt(*v)
}
return _u
}
// ClearEndsAt clears the value of the "ends_at" field.
func (_u *AnnouncementUpdateOne) ClearEndsAt() *AnnouncementUpdateOne {
_u.mutation.ClearEndsAt()
return _u
}
// SetCreatedBy sets the "created_by" field.
func (_u *AnnouncementUpdateOne) SetCreatedBy(v int64) *AnnouncementUpdateOne {
_u.mutation.ResetCreatedBy()
_u.mutation.SetCreatedBy(v)
return _u
}
// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableCreatedBy(v *int64) *AnnouncementUpdateOne {
if v != nil {
_u.SetCreatedBy(*v)
}
return _u
}
// AddCreatedBy adds value to the "created_by" field.
func (_u *AnnouncementUpdateOne) AddCreatedBy(v int64) *AnnouncementUpdateOne {
_u.mutation.AddCreatedBy(v)
return _u
}
// ClearCreatedBy clears the value of the "created_by" field.
func (_u *AnnouncementUpdateOne) ClearCreatedBy() *AnnouncementUpdateOne {
_u.mutation.ClearCreatedBy()
return _u
}
// SetUpdatedBy sets the "updated_by" field.
func (_u *AnnouncementUpdateOne) SetUpdatedBy(v int64) *AnnouncementUpdateOne {
_u.mutation.ResetUpdatedBy()
_u.mutation.SetUpdatedBy(v)
return _u
}
// SetNillableUpdatedBy sets the "updated_by" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableUpdatedBy(v *int64) *AnnouncementUpdateOne {
if v != nil {
_u.SetUpdatedBy(*v)
}
return _u
}
// AddUpdatedBy adds value to the "updated_by" field.
func (_u *AnnouncementUpdateOne) AddUpdatedBy(v int64) *AnnouncementUpdateOne {
_u.mutation.AddUpdatedBy(v)
return _u
}
// ClearUpdatedBy clears the value of the "updated_by" field.
func (_u *AnnouncementUpdateOne) ClearUpdatedBy() *AnnouncementUpdateOne {
_u.mutation.ClearUpdatedBy()
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *AnnouncementUpdateOne) SetUpdatedAt(v time.Time) *AnnouncementUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// AddReadIDs adds the "reads" edge to the AnnouncementRead entity by IDs.
func (_u *AnnouncementUpdateOne) AddReadIDs(ids ...int64) *AnnouncementUpdateOne {
_u.mutation.AddReadIDs(ids...)
return _u
}
// AddReads adds the "reads" edges to the AnnouncementRead entity.
func (_u *AnnouncementUpdateOne) AddReads(v ...*AnnouncementRead) *AnnouncementUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddReadIDs(ids...)
}
// Mutation returns the AnnouncementMutation object of the builder.
func (_u *AnnouncementUpdateOne) Mutation() *AnnouncementMutation {
return _u.mutation
}
// ClearReads clears all "reads" edges to the AnnouncementRead entity.
func (_u *AnnouncementUpdateOne) ClearReads() *AnnouncementUpdateOne {
_u.mutation.ClearReads()
return _u
}
// RemoveReadIDs removes the "reads" edge to AnnouncementRead entities by IDs.
func (_u *AnnouncementUpdateOne) RemoveReadIDs(ids ...int64) *AnnouncementUpdateOne {
_u.mutation.RemoveReadIDs(ids...)
return _u
}
// RemoveReads removes "reads" edges to AnnouncementRead entities.
func (_u *AnnouncementUpdateOne) RemoveReads(v ...*AnnouncementRead) *AnnouncementUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveReadIDs(ids...)
}
// Where appends a list predicates to the AnnouncementUpdate builder.
func (_u *AnnouncementUpdateOne) Where(ps ...predicate.Announcement) *AnnouncementUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *AnnouncementUpdateOne) Select(field string, fields ...string) *AnnouncementUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated Announcement entity.
func (_u *AnnouncementUpdateOne) Save(ctx context.Context) (*Announcement, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AnnouncementUpdateOne) SaveX(ctx context.Context) *Announcement {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *AnnouncementUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AnnouncementUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *AnnouncementUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := announcement.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *AnnouncementUpdateOne) check() error {
if v, ok := _u.mutation.Title(); ok {
if err := announcement.TitleValidator(v); err != nil {
return &ValidationError{Name: "title", err: fmt.Errorf(`ent: validator failed for field "Announcement.title": %w`, err)}
}
}
if v, ok := _u.mutation.Content(); ok {
if err := announcement.ContentValidator(v); err != nil {
return &ValidationError{Name: "content", err: fmt.Errorf(`ent: validator failed for field "Announcement.content": %w`, err)}
}
}
if v, ok := _u.mutation.Status(); ok {
if err := announcement.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
}
}
return nil
}
func (_u *AnnouncementUpdateOne) sqlSave(ctx context.Context) (_node *Announcement, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(announcement.Table, announcement.Columns, sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Announcement.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, announcement.FieldID)
for _, f := range fields {
if !announcement.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != announcement.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.Title(); ok {
_spec.SetField(announcement.FieldTitle, field.TypeString, value)
}
if value, ok := _u.mutation.Content(); ok {
_spec.SetField(announcement.FieldContent, field.TypeString, value)
}
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.Targeting(); ok {
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
}
if _u.mutation.TargetingCleared() {
_spec.ClearField(announcement.FieldTargeting, field.TypeJSON)
}
if value, ok := _u.mutation.StartsAt(); ok {
_spec.SetField(announcement.FieldStartsAt, field.TypeTime, value)
}
if _u.mutation.StartsAtCleared() {
_spec.ClearField(announcement.FieldStartsAt, field.TypeTime)
}
if value, ok := _u.mutation.EndsAt(); ok {
_spec.SetField(announcement.FieldEndsAt, field.TypeTime, value)
}
if _u.mutation.EndsAtCleared() {
_spec.ClearField(announcement.FieldEndsAt, field.TypeTime)
}
if value, ok := _u.mutation.CreatedBy(); ok {
_spec.SetField(announcement.FieldCreatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedCreatedBy(); ok {
_spec.AddField(announcement.FieldCreatedBy, field.TypeInt64, value)
}
if _u.mutation.CreatedByCleared() {
_spec.ClearField(announcement.FieldCreatedBy, field.TypeInt64)
}
if value, ok := _u.mutation.UpdatedBy(); ok {
_spec.SetField(announcement.FieldUpdatedBy, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedUpdatedBy(); ok {
_spec.AddField(announcement.FieldUpdatedBy, field.TypeInt64, value)
}
if _u.mutation.UpdatedByCleared() {
_spec.ClearField(announcement.FieldUpdatedBy, field.TypeInt64)
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(announcement.FieldUpdatedAt, field.TypeTime, value)
}
if _u.mutation.ReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: announcement.ReadsTable,
Columns: []string{announcement.ReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedReadsIDs(); len(nodes) > 0 && !_u.mutation.ReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: announcement.ReadsTable,
Columns: []string{announcement.ReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.ReadsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: announcement.ReadsTable,
Columns: []string{announcement.ReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &Announcement{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{announcement.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -0,0 +1,185 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// AnnouncementRead is the model entity for the AnnouncementRead schema.
type AnnouncementRead struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// AnnouncementID holds the value of the "announcement_id" field.
AnnouncementID int64 `json:"announcement_id,omitempty"`
// UserID holds the value of the "user_id" field.
UserID int64 `json:"user_id,omitempty"`
// 用户首次已读时间
ReadAt time.Time `json:"read_at,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the AnnouncementReadQuery when eager-loading is set.
Edges AnnouncementReadEdges `json:"edges"`
selectValues sql.SelectValues
}
// AnnouncementReadEdges holds the relations/edges for other nodes in the graph.
type AnnouncementReadEdges struct {
// Announcement holds the value of the announcement edge.
Announcement *Announcement `json:"announcement,omitempty"`
// User holds the value of the user edge.
User *User `json:"user,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [2]bool
}
// AnnouncementOrErr returns the Announcement value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e AnnouncementReadEdges) AnnouncementOrErr() (*Announcement, error) {
if e.Announcement != nil {
return e.Announcement, nil
} else if e.loadedTypes[0] {
return nil, &NotFoundError{label: announcement.Label}
}
return nil, &NotLoadedError{edge: "announcement"}
}
// UserOrErr returns the User value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e AnnouncementReadEdges) UserOrErr() (*User, error) {
if e.User != nil {
return e.User, nil
} else if e.loadedTypes[1] {
return nil, &NotFoundError{label: user.Label}
}
return nil, &NotLoadedError{edge: "user"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*AnnouncementRead) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case announcementread.FieldID, announcementread.FieldAnnouncementID, announcementread.FieldUserID:
values[i] = new(sql.NullInt64)
case announcementread.FieldReadAt, announcementread.FieldCreatedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the AnnouncementRead fields.
func (_m *AnnouncementRead) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case announcementread.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case announcementread.FieldAnnouncementID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field announcement_id", values[i])
} else if value.Valid {
_m.AnnouncementID = value.Int64
}
case announcementread.FieldUserID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field user_id", values[i])
} else if value.Valid {
_m.UserID = value.Int64
}
case announcementread.FieldReadAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field read_at", values[i])
} else if value.Valid {
_m.ReadAt = value.Time
}
case announcementread.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the AnnouncementRead.
// This includes values selected through modifiers, order, etc.
func (_m *AnnouncementRead) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryAnnouncement queries the "announcement" edge of the AnnouncementRead entity.
func (_m *AnnouncementRead) QueryAnnouncement() *AnnouncementQuery {
return NewAnnouncementReadClient(_m.config).QueryAnnouncement(_m)
}
// QueryUser queries the "user" edge of the AnnouncementRead entity.
func (_m *AnnouncementRead) QueryUser() *UserQuery {
return NewAnnouncementReadClient(_m.config).QueryUser(_m)
}
// Update returns a builder for updating this AnnouncementRead.
// Note that you need to call AnnouncementRead.Unwrap() before calling this method if this AnnouncementRead
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *AnnouncementRead) Update() *AnnouncementReadUpdateOne {
return NewAnnouncementReadClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the AnnouncementRead entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *AnnouncementRead) Unwrap() *AnnouncementRead {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: AnnouncementRead is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *AnnouncementRead) String() string {
var builder strings.Builder
builder.WriteString("AnnouncementRead(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("announcement_id=")
builder.WriteString(fmt.Sprintf("%v", _m.AnnouncementID))
builder.WriteString(", ")
builder.WriteString("user_id=")
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
builder.WriteString(", ")
builder.WriteString("read_at=")
builder.WriteString(_m.ReadAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteByte(')')
return builder.String()
}
// AnnouncementReads is a parsable slice of AnnouncementRead.
type AnnouncementReads []*AnnouncementRead

View File

@@ -0,0 +1,127 @@
// Code generated by ent, DO NOT EDIT.
package announcementread
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the announcementread type in the database.
Label = "announcement_read"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldAnnouncementID holds the string denoting the announcement_id field in the database.
FieldAnnouncementID = "announcement_id"
// FieldUserID holds the string denoting the user_id field in the database.
FieldUserID = "user_id"
// FieldReadAt holds the string denoting the read_at field in the database.
FieldReadAt = "read_at"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// EdgeAnnouncement holds the string denoting the announcement edge name in mutations.
EdgeAnnouncement = "announcement"
// EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user"
// Table holds the table name of the announcementread in the database.
Table = "announcement_reads"
// AnnouncementTable is the table that holds the announcement relation/edge.
AnnouncementTable = "announcement_reads"
// AnnouncementInverseTable is the table name for the Announcement entity.
// It exists in this package in order to avoid circular dependency with the "announcement" package.
AnnouncementInverseTable = "announcements"
// AnnouncementColumn is the table column denoting the announcement relation/edge.
AnnouncementColumn = "announcement_id"
// UserTable is the table that holds the user relation/edge.
UserTable = "announcement_reads"
// UserInverseTable is the table name for the User entity.
// It exists in this package in order to avoid circular dependency with the "user" package.
UserInverseTable = "users"
// UserColumn is the table column denoting the user relation/edge.
UserColumn = "user_id"
)
// Columns holds all SQL columns for announcementread fields.
var Columns = []string{
FieldID,
FieldAnnouncementID,
FieldUserID,
FieldReadAt,
FieldCreatedAt,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultReadAt holds the default value on creation for the "read_at" field.
DefaultReadAt func() time.Time
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
)
// OrderOption defines the ordering options for the AnnouncementRead queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByAnnouncementID orders the results by the announcement_id field.
func ByAnnouncementID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAnnouncementID, opts...).ToFunc()
}
// ByUserID orders the results by the user_id field.
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserID, opts...).ToFunc()
}
// ByReadAt orders the results by the read_at field.
func ByReadAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldReadAt, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByAnnouncementField orders the results by announcement field.
func ByAnnouncementField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAnnouncementStep(), sql.OrderByField(field, opts...))
}
}
// ByUserField orders the results by user field.
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
}
}
func newAnnouncementStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AnnouncementInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, AnnouncementTable, AnnouncementColumn),
)
}
func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UserInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
}

View File

@@ -0,0 +1,257 @@
// Code generated by ent, DO NOT EDIT.
package announcementread
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldLTE(FieldID, id))
}
// AnnouncementID applies equality check predicate on the "announcement_id" field. It's identical to AnnouncementIDEQ.
func AnnouncementID(v int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldAnnouncementID, v))
}
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
func UserID(v int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldUserID, v))
}
// ReadAt applies equality check predicate on the "read_at" field. It's identical to ReadAtEQ.
func ReadAt(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldReadAt, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldCreatedAt, v))
}
// AnnouncementIDEQ applies the EQ predicate on the "announcement_id" field.
func AnnouncementIDEQ(v int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldAnnouncementID, v))
}
// AnnouncementIDNEQ applies the NEQ predicate on the "announcement_id" field.
func AnnouncementIDNEQ(v int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNEQ(FieldAnnouncementID, v))
}
// AnnouncementIDIn applies the In predicate on the "announcement_id" field.
func AnnouncementIDIn(vs ...int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldIn(FieldAnnouncementID, vs...))
}
// AnnouncementIDNotIn applies the NotIn predicate on the "announcement_id" field.
func AnnouncementIDNotIn(vs ...int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNotIn(FieldAnnouncementID, vs...))
}
// UserIDEQ applies the EQ predicate on the "user_id" field.
func UserIDEQ(v int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldUserID, v))
}
// UserIDNEQ applies the NEQ predicate on the "user_id" field.
func UserIDNEQ(v int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNEQ(FieldUserID, v))
}
// UserIDIn applies the In predicate on the "user_id" field.
func UserIDIn(vs ...int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldIn(FieldUserID, vs...))
}
// UserIDNotIn applies the NotIn predicate on the "user_id" field.
func UserIDNotIn(vs ...int64) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNotIn(FieldUserID, vs...))
}
// ReadAtEQ applies the EQ predicate on the "read_at" field.
func ReadAtEQ(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldReadAt, v))
}
// ReadAtNEQ applies the NEQ predicate on the "read_at" field.
func ReadAtNEQ(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNEQ(FieldReadAt, v))
}
// ReadAtIn applies the In predicate on the "read_at" field.
func ReadAtIn(vs ...time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldIn(FieldReadAt, vs...))
}
// ReadAtNotIn applies the NotIn predicate on the "read_at" field.
func ReadAtNotIn(vs ...time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNotIn(FieldReadAt, vs...))
}
// ReadAtGT applies the GT predicate on the "read_at" field.
func ReadAtGT(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldGT(FieldReadAt, v))
}
// ReadAtGTE applies the GTE predicate on the "read_at" field.
func ReadAtGTE(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldGTE(FieldReadAt, v))
}
// ReadAtLT applies the LT predicate on the "read_at" field.
func ReadAtLT(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldLT(FieldReadAt, v))
}
// ReadAtLTE applies the LTE predicate on the "read_at" field.
func ReadAtLTE(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldLTE(FieldReadAt, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.FieldLTE(FieldCreatedAt, v))
}
// HasAnnouncement applies the HasEdge predicate on the "announcement" edge.
func HasAnnouncement() predicate.AnnouncementRead {
return predicate.AnnouncementRead(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, AnnouncementTable, AnnouncementColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasAnnouncementWith applies the HasEdge predicate on the "announcement" edge with a given conditions (other predicates).
func HasAnnouncementWith(preds ...predicate.Announcement) predicate.AnnouncementRead {
return predicate.AnnouncementRead(func(s *sql.Selector) {
step := newAnnouncementStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.AnnouncementRead {
return predicate.AnnouncementRead(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
func HasUserWith(preds ...predicate.User) predicate.AnnouncementRead {
return predicate.AnnouncementRead(func(s *sql.Selector) {
step := newUserStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.AnnouncementRead) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.AnnouncementRead) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.AnnouncementRead) predicate.AnnouncementRead {
return predicate.AnnouncementRead(sql.NotPredicates(p))
}

View File

@@ -0,0 +1,660 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// AnnouncementReadCreate is the builder for creating a AnnouncementRead entity.
type AnnouncementReadCreate struct {
config
mutation *AnnouncementReadMutation
hooks []Hook
conflict []sql.ConflictOption
}
// SetAnnouncementID sets the "announcement_id" field.
func (_c *AnnouncementReadCreate) SetAnnouncementID(v int64) *AnnouncementReadCreate {
_c.mutation.SetAnnouncementID(v)
return _c
}
// SetUserID sets the "user_id" field.
func (_c *AnnouncementReadCreate) SetUserID(v int64) *AnnouncementReadCreate {
_c.mutation.SetUserID(v)
return _c
}
// SetReadAt sets the "read_at" field.
func (_c *AnnouncementReadCreate) SetReadAt(v time.Time) *AnnouncementReadCreate {
_c.mutation.SetReadAt(v)
return _c
}
// SetNillableReadAt sets the "read_at" field if the given value is not nil.
func (_c *AnnouncementReadCreate) SetNillableReadAt(v *time.Time) *AnnouncementReadCreate {
if v != nil {
_c.SetReadAt(*v)
}
return _c
}
// SetCreatedAt sets the "created_at" field.
func (_c *AnnouncementReadCreate) SetCreatedAt(v time.Time) *AnnouncementReadCreate {
_c.mutation.SetCreatedAt(v)
return _c
}
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
func (_c *AnnouncementReadCreate) SetNillableCreatedAt(v *time.Time) *AnnouncementReadCreate {
if v != nil {
_c.SetCreatedAt(*v)
}
return _c
}
// SetAnnouncement sets the "announcement" edge to the Announcement entity.
func (_c *AnnouncementReadCreate) SetAnnouncement(v *Announcement) *AnnouncementReadCreate {
return _c.SetAnnouncementID(v.ID)
}
// SetUser sets the "user" edge to the User entity.
func (_c *AnnouncementReadCreate) SetUser(v *User) *AnnouncementReadCreate {
return _c.SetUserID(v.ID)
}
// Mutation returns the AnnouncementReadMutation object of the builder.
func (_c *AnnouncementReadCreate) Mutation() *AnnouncementReadMutation {
return _c.mutation
}
// Save creates the AnnouncementRead in the database.
func (_c *AnnouncementReadCreate) Save(ctx context.Context) (*AnnouncementRead, error) {
_c.defaults()
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func (_c *AnnouncementReadCreate) SaveX(ctx context.Context) *AnnouncementRead {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *AnnouncementReadCreate) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *AnnouncementReadCreate) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_c *AnnouncementReadCreate) defaults() {
if _, ok := _c.mutation.ReadAt(); !ok {
v := announcementread.DefaultReadAt()
_c.mutation.SetReadAt(v)
}
if _, ok := _c.mutation.CreatedAt(); !ok {
v := announcementread.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_c *AnnouncementReadCreate) check() error {
if _, ok := _c.mutation.AnnouncementID(); !ok {
return &ValidationError{Name: "announcement_id", err: errors.New(`ent: missing required field "AnnouncementRead.announcement_id"`)}
}
if _, ok := _c.mutation.UserID(); !ok {
return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AnnouncementRead.user_id"`)}
}
if _, ok := _c.mutation.ReadAt(); !ok {
return &ValidationError{Name: "read_at", err: errors.New(`ent: missing required field "AnnouncementRead.read_at"`)}
}
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AnnouncementRead.created_at"`)}
}
if len(_c.mutation.AnnouncementIDs()) == 0 {
return &ValidationError{Name: "announcement", err: errors.New(`ent: missing required edge "AnnouncementRead.announcement"`)}
}
if len(_c.mutation.UserIDs()) == 0 {
return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AnnouncementRead.user"`)}
}
return nil
}
func (_c *AnnouncementReadCreate) sqlSave(ctx context.Context) (*AnnouncementRead, error) {
if err := _c.check(); err != nil {
return nil, err
}
_node, _spec := _c.createSpec()
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
id := _spec.ID.Value.(int64)
_node.ID = int64(id)
_c.mutation.id = &_node.ID
_c.mutation.done = true
return _node, nil
}
func (_c *AnnouncementReadCreate) createSpec() (*AnnouncementRead, *sqlgraph.CreateSpec) {
var (
_node = &AnnouncementRead{config: _c.config}
_spec = sqlgraph.NewCreateSpec(announcementread.Table, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
)
_spec.OnConflict = _c.conflict
if value, ok := _c.mutation.ReadAt(); ok {
_spec.SetField(announcementread.FieldReadAt, field.TypeTime, value)
_node.ReadAt = value
}
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(announcementread.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
if nodes := _c.mutation.AnnouncementIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.AnnouncementTable,
Columns: []string{announcementread.AnnouncementColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.AnnouncementID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.UserTable,
Columns: []string{announcementread.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.UserID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.AnnouncementRead.Create().
// SetAnnouncementID(v).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.AnnouncementReadUpsert) {
// SetAnnouncementID(v+v).
// }).
// Exec(ctx)
func (_c *AnnouncementReadCreate) OnConflict(opts ...sql.ConflictOption) *AnnouncementReadUpsertOne {
_c.conflict = opts
return &AnnouncementReadUpsertOne{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.AnnouncementRead.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *AnnouncementReadCreate) OnConflictColumns(columns ...string) *AnnouncementReadUpsertOne {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &AnnouncementReadUpsertOne{
create: _c,
}
}
type (
// AnnouncementReadUpsertOne is the builder for "upsert"-ing
// one AnnouncementRead node.
AnnouncementReadUpsertOne struct {
create *AnnouncementReadCreate
}
// AnnouncementReadUpsert is the "OnConflict" setter.
AnnouncementReadUpsert struct {
*sql.UpdateSet
}
)
// SetAnnouncementID sets the "announcement_id" field.
func (u *AnnouncementReadUpsert) SetAnnouncementID(v int64) *AnnouncementReadUpsert {
u.Set(announcementread.FieldAnnouncementID, v)
return u
}
// UpdateAnnouncementID sets the "announcement_id" field to the value that was provided on create.
func (u *AnnouncementReadUpsert) UpdateAnnouncementID() *AnnouncementReadUpsert {
u.SetExcluded(announcementread.FieldAnnouncementID)
return u
}
// SetUserID sets the "user_id" field.
func (u *AnnouncementReadUpsert) SetUserID(v int64) *AnnouncementReadUpsert {
u.Set(announcementread.FieldUserID, v)
return u
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *AnnouncementReadUpsert) UpdateUserID() *AnnouncementReadUpsert {
u.SetExcluded(announcementread.FieldUserID)
return u
}
// SetReadAt sets the "read_at" field.
func (u *AnnouncementReadUpsert) SetReadAt(v time.Time) *AnnouncementReadUpsert {
u.Set(announcementread.FieldReadAt, v)
return u
}
// UpdateReadAt sets the "read_at" field to the value that was provided on create.
func (u *AnnouncementReadUpsert) UpdateReadAt() *AnnouncementReadUpsert {
u.SetExcluded(announcementread.FieldReadAt)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
// client.AnnouncementRead.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *AnnouncementReadUpsertOne) UpdateNewValues() *AnnouncementReadUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
if _, exists := u.create.mutation.CreatedAt(); exists {
s.SetIgnore(announcementread.FieldCreatedAt)
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.AnnouncementRead.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *AnnouncementReadUpsertOne) Ignore() *AnnouncementReadUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *AnnouncementReadUpsertOne) DoNothing() *AnnouncementReadUpsertOne {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the AnnouncementReadCreate.OnConflict
// documentation for more info.
func (u *AnnouncementReadUpsertOne) Update(set func(*AnnouncementReadUpsert)) *AnnouncementReadUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&AnnouncementReadUpsert{UpdateSet: update})
}))
return u
}
// SetAnnouncementID sets the "announcement_id" field.
func (u *AnnouncementReadUpsertOne) SetAnnouncementID(v int64) *AnnouncementReadUpsertOne {
return u.Update(func(s *AnnouncementReadUpsert) {
s.SetAnnouncementID(v)
})
}
// UpdateAnnouncementID sets the "announcement_id" field to the value that was provided on create.
func (u *AnnouncementReadUpsertOne) UpdateAnnouncementID() *AnnouncementReadUpsertOne {
return u.Update(func(s *AnnouncementReadUpsert) {
s.UpdateAnnouncementID()
})
}
// SetUserID sets the "user_id" field.
func (u *AnnouncementReadUpsertOne) SetUserID(v int64) *AnnouncementReadUpsertOne {
return u.Update(func(s *AnnouncementReadUpsert) {
s.SetUserID(v)
})
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *AnnouncementReadUpsertOne) UpdateUserID() *AnnouncementReadUpsertOne {
return u.Update(func(s *AnnouncementReadUpsert) {
s.UpdateUserID()
})
}
// SetReadAt sets the "read_at" field.
func (u *AnnouncementReadUpsertOne) SetReadAt(v time.Time) *AnnouncementReadUpsertOne {
return u.Update(func(s *AnnouncementReadUpsert) {
s.SetReadAt(v)
})
}
// UpdateReadAt sets the "read_at" field to the value that was provided on create.
func (u *AnnouncementReadUpsertOne) UpdateReadAt() *AnnouncementReadUpsertOne {
return u.Update(func(s *AnnouncementReadUpsert) {
s.UpdateReadAt()
})
}
// Exec executes the query.
func (u *AnnouncementReadUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for AnnouncementReadCreate.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *AnnouncementReadUpsertOne) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}
// Exec executes the UPSERT query and returns the inserted/updated ID.
func (u *AnnouncementReadUpsertOne) ID(ctx context.Context) (id int64, err error) {
node, err := u.create.Save(ctx)
if err != nil {
return id, err
}
return node.ID, nil
}
// IDX is like ID, but panics if an error occurs.
func (u *AnnouncementReadUpsertOne) IDX(ctx context.Context) int64 {
id, err := u.ID(ctx)
if err != nil {
panic(err)
}
return id
}
// AnnouncementReadCreateBulk is the builder for creating many AnnouncementRead entities in bulk.
type AnnouncementReadCreateBulk struct {
config
err error
builders []*AnnouncementReadCreate
conflict []sql.ConflictOption
}
// Save creates the AnnouncementRead entities in the database.
func (_c *AnnouncementReadCreateBulk) Save(ctx context.Context) ([]*AnnouncementRead, error) {
if _c.err != nil {
return nil, _c.err
}
specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
nodes := make([]*AnnouncementRead, len(_c.builders))
mutators := make([]Mutator, len(_c.builders))
for i := range _c.builders {
func(i int, root context.Context) {
builder := _c.builders[i]
builder.defaults()
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*AnnouncementReadMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err := builder.check(); err != nil {
return nil, err
}
builder.mutation = mutation
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
spec.OnConflict = _c.conflict
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
}
}
if err != nil {
return nil, err
}
mutation.id = &nodes[i].ID
if specs[i].ID.Value != nil {
id := specs[i].ID.Value.(int64)
nodes[i].ID = int64(id)
}
mutation.done = true
return nodes[i], nil
})
for i := len(builder.hooks) - 1; i >= 0; i-- {
mut = builder.hooks[i](mut)
}
mutators[i] = mut
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
return nil, err
}
}
return nodes, nil
}
// SaveX is like Save, but panics if an error occurs.
func (_c *AnnouncementReadCreateBulk) SaveX(ctx context.Context) []*AnnouncementRead {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *AnnouncementReadCreateBulk) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *AnnouncementReadCreateBulk) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.AnnouncementRead.CreateBulk(builders...).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.AnnouncementReadUpsert) {
// SetAnnouncementID(v+v).
// }).
// Exec(ctx)
func (_c *AnnouncementReadCreateBulk) OnConflict(opts ...sql.ConflictOption) *AnnouncementReadUpsertBulk {
_c.conflict = opts
return &AnnouncementReadUpsertBulk{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.AnnouncementRead.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *AnnouncementReadCreateBulk) OnConflictColumns(columns ...string) *AnnouncementReadUpsertBulk {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &AnnouncementReadUpsertBulk{
create: _c,
}
}
// AnnouncementReadUpsertBulk is the builder for "upsert"-ing
// a bulk of AnnouncementRead nodes.
type AnnouncementReadUpsertBulk struct {
create *AnnouncementReadCreateBulk
}
// UpdateNewValues updates the mutable fields using the new values that
// were set on create. Using this option is equivalent to using:
//
// client.AnnouncementRead.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *AnnouncementReadUpsertBulk) UpdateNewValues() *AnnouncementReadUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
for _, b := range u.create.builders {
if _, exists := b.mutation.CreatedAt(); exists {
s.SetIgnore(announcementread.FieldCreatedAt)
}
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.AnnouncementRead.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *AnnouncementReadUpsertBulk) Ignore() *AnnouncementReadUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *AnnouncementReadUpsertBulk) DoNothing() *AnnouncementReadUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the AnnouncementReadCreateBulk.OnConflict
// documentation for more info.
func (u *AnnouncementReadUpsertBulk) Update(set func(*AnnouncementReadUpsert)) *AnnouncementReadUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&AnnouncementReadUpsert{UpdateSet: update})
}))
return u
}
// SetAnnouncementID sets the "announcement_id" field.
func (u *AnnouncementReadUpsertBulk) SetAnnouncementID(v int64) *AnnouncementReadUpsertBulk {
return u.Update(func(s *AnnouncementReadUpsert) {
s.SetAnnouncementID(v)
})
}
// UpdateAnnouncementID sets the "announcement_id" field to the value that was provided on create.
func (u *AnnouncementReadUpsertBulk) UpdateAnnouncementID() *AnnouncementReadUpsertBulk {
return u.Update(func(s *AnnouncementReadUpsert) {
s.UpdateAnnouncementID()
})
}
// SetUserID sets the "user_id" field.
func (u *AnnouncementReadUpsertBulk) SetUserID(v int64) *AnnouncementReadUpsertBulk {
return u.Update(func(s *AnnouncementReadUpsert) {
s.SetUserID(v)
})
}
// UpdateUserID sets the "user_id" field to the value that was provided on create.
func (u *AnnouncementReadUpsertBulk) UpdateUserID() *AnnouncementReadUpsertBulk {
return u.Update(func(s *AnnouncementReadUpsert) {
s.UpdateUserID()
})
}
// SetReadAt sets the "read_at" field.
func (u *AnnouncementReadUpsertBulk) SetReadAt(v time.Time) *AnnouncementReadUpsertBulk {
return u.Update(func(s *AnnouncementReadUpsert) {
s.SetReadAt(v)
})
}
// UpdateReadAt sets the "read_at" field to the value that was provided on create.
func (u *AnnouncementReadUpsertBulk) UpdateReadAt() *AnnouncementReadUpsertBulk {
return u.Update(func(s *AnnouncementReadUpsert) {
s.UpdateReadAt()
})
}
// Exec executes the query.
func (u *AnnouncementReadUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
return u.create.err
}
for i, b := range u.create.builders {
if len(b.conflict) != 0 {
return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AnnouncementReadCreateBulk instead", i)
}
}
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for AnnouncementReadCreateBulk.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *AnnouncementReadUpsertBulk) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// AnnouncementReadDelete is the builder for deleting a AnnouncementRead entity.
type AnnouncementReadDelete struct {
config
hooks []Hook
mutation *AnnouncementReadMutation
}
// Where appends a list predicates to the AnnouncementReadDelete builder.
func (_d *AnnouncementReadDelete) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *AnnouncementReadDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AnnouncementReadDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *AnnouncementReadDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(announcementread.Table, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// AnnouncementReadDeleteOne is the builder for deleting a single AnnouncementRead entity.
type AnnouncementReadDeleteOne struct {
_d *AnnouncementReadDelete
}
// Where appends a list predicates to the AnnouncementReadDelete builder.
func (_d *AnnouncementReadDeleteOne) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *AnnouncementReadDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{announcementread.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AnnouncementReadDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,718 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// AnnouncementReadQuery is the builder for querying AnnouncementRead entities.
type AnnouncementReadQuery struct {
config
ctx *QueryContext
order []announcementread.OrderOption
inters []Interceptor
predicates []predicate.AnnouncementRead
withAnnouncement *AnnouncementQuery
withUser *UserQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the AnnouncementReadQuery builder.
func (_q *AnnouncementReadQuery) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *AnnouncementReadQuery) Limit(limit int) *AnnouncementReadQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *AnnouncementReadQuery) Offset(offset int) *AnnouncementReadQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *AnnouncementReadQuery) Unique(unique bool) *AnnouncementReadQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *AnnouncementReadQuery) Order(o ...announcementread.OrderOption) *AnnouncementReadQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryAnnouncement chains the current query on the "announcement" edge.
func (_q *AnnouncementReadQuery) QueryAnnouncement() *AnnouncementQuery {
query := (&AnnouncementClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(announcementread.Table, announcementread.FieldID, selector),
sqlgraph.To(announcement.Table, announcement.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, announcementread.AnnouncementTable, announcementread.AnnouncementColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUser chains the current query on the "user" edge.
func (_q *AnnouncementReadQuery) QueryUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(announcementread.Table, announcementread.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, announcementread.UserTable, announcementread.UserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first AnnouncementRead entity from the query.
// Returns a *NotFoundError when no AnnouncementRead was found.
func (_q *AnnouncementReadQuery) First(ctx context.Context) (*AnnouncementRead, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{announcementread.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *AnnouncementReadQuery) FirstX(ctx context.Context) *AnnouncementRead {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first AnnouncementRead ID from the query.
// Returns a *NotFoundError when no AnnouncementRead ID was found.
func (_q *AnnouncementReadQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{announcementread.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *AnnouncementReadQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single AnnouncementRead entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one AnnouncementRead entity is found.
// Returns a *NotFoundError when no AnnouncementRead entities are found.
func (_q *AnnouncementReadQuery) Only(ctx context.Context) (*AnnouncementRead, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{announcementread.Label}
default:
return nil, &NotSingularError{announcementread.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *AnnouncementReadQuery) OnlyX(ctx context.Context) *AnnouncementRead {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only AnnouncementRead ID in the query.
// Returns a *NotSingularError when more than one AnnouncementRead ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *AnnouncementReadQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{announcementread.Label}
default:
err = &NotSingularError{announcementread.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *AnnouncementReadQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of AnnouncementReads.
func (_q *AnnouncementReadQuery) All(ctx context.Context) ([]*AnnouncementRead, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*AnnouncementRead, *AnnouncementReadQuery]()
return withInterceptors[[]*AnnouncementRead](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *AnnouncementReadQuery) AllX(ctx context.Context) []*AnnouncementRead {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of AnnouncementRead IDs.
func (_q *AnnouncementReadQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(announcementread.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *AnnouncementReadQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *AnnouncementReadQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*AnnouncementReadQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *AnnouncementReadQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *AnnouncementReadQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *AnnouncementReadQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the AnnouncementReadQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *AnnouncementReadQuery) Clone() *AnnouncementReadQuery {
if _q == nil {
return nil
}
return &AnnouncementReadQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]announcementread.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.AnnouncementRead{}, _q.predicates...),
withAnnouncement: _q.withAnnouncement.Clone(),
withUser: _q.withUser.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithAnnouncement tells the query-builder to eager-load the nodes that are connected to
// the "announcement" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AnnouncementReadQuery) WithAnnouncement(opts ...func(*AnnouncementQuery)) *AnnouncementReadQuery {
query := (&AnnouncementClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAnnouncement = query
return _q
}
// WithUser tells the query-builder to eager-load the nodes that are connected to
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AnnouncementReadQuery) WithUser(opts ...func(*UserQuery)) *AnnouncementReadQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUser = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// AnnouncementID int64 `json:"announcement_id,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.AnnouncementRead.Query().
// GroupBy(announcementread.FieldAnnouncementID).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *AnnouncementReadQuery) GroupBy(field string, fields ...string) *AnnouncementReadGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &AnnouncementReadGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = announcementread.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// AnnouncementID int64 `json:"announcement_id,omitempty"`
// }
//
// client.AnnouncementRead.Query().
// Select(announcementread.FieldAnnouncementID).
// Scan(ctx, &v)
func (_q *AnnouncementReadQuery) Select(fields ...string) *AnnouncementReadSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &AnnouncementReadSelect{AnnouncementReadQuery: _q}
sbuild.label = announcementread.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a AnnouncementReadSelect configured with the given aggregations.
func (_q *AnnouncementReadQuery) Aggregate(fns ...AggregateFunc) *AnnouncementReadSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *AnnouncementReadQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !announcementread.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *AnnouncementReadQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AnnouncementRead, error) {
var (
nodes = []*AnnouncementRead{}
_spec = _q.querySpec()
loadedTypes = [2]bool{
_q.withAnnouncement != nil,
_q.withUser != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*AnnouncementRead).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &AnnouncementRead{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withAnnouncement; query != nil {
if err := _q.loadAnnouncement(ctx, query, nodes, nil,
func(n *AnnouncementRead, e *Announcement) { n.Edges.Announcement = e }); err != nil {
return nil, err
}
}
if query := _q.withUser; query != nil {
if err := _q.loadUser(ctx, query, nodes, nil,
func(n *AnnouncementRead, e *User) { n.Edges.User = e }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *AnnouncementReadQuery) loadAnnouncement(ctx context.Context, query *AnnouncementQuery, nodes []*AnnouncementRead, init func(*AnnouncementRead), assign func(*AnnouncementRead, *Announcement)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*AnnouncementRead)
for i := range nodes {
fk := nodes[i].AnnouncementID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(announcement.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "announcement_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *AnnouncementReadQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AnnouncementRead, init func(*AnnouncementRead), assign func(*AnnouncementRead, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*AnnouncementRead)
for i := range nodes {
fk := nodes[i].UserID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *AnnouncementReadQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *AnnouncementReadQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(announcementread.Table, announcementread.Columns, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, announcementread.FieldID)
for i := range fields {
if fields[i] != announcementread.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withAnnouncement != nil {
_spec.Node.AddColumnOnce(announcementread.FieldAnnouncementID)
}
if _q.withUser != nil {
_spec.Node.AddColumnOnce(announcementread.FieldUserID)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *AnnouncementReadQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(announcementread.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = announcementread.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *AnnouncementReadQuery) ForUpdate(opts ...sql.LockOption) *AnnouncementReadQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *AnnouncementReadQuery) ForShare(opts ...sql.LockOption) *AnnouncementReadQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// AnnouncementReadGroupBy is the group-by builder for AnnouncementRead entities.
type AnnouncementReadGroupBy struct {
selector
build *AnnouncementReadQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *AnnouncementReadGroupBy) Aggregate(fns ...AggregateFunc) *AnnouncementReadGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *AnnouncementReadGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AnnouncementReadQuery, *AnnouncementReadGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *AnnouncementReadGroupBy) sqlScan(ctx context.Context, root *AnnouncementReadQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// AnnouncementReadSelect is the builder for selecting fields of AnnouncementRead entities.
type AnnouncementReadSelect struct {
*AnnouncementReadQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *AnnouncementReadSelect) Aggregate(fns ...AggregateFunc) *AnnouncementReadSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *AnnouncementReadSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AnnouncementReadQuery, *AnnouncementReadSelect](ctx, _s.AnnouncementReadQuery, _s, _s.inters, v)
}
func (_s *AnnouncementReadSelect) sqlScan(ctx context.Context, root *AnnouncementReadQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,456 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// AnnouncementReadUpdate is the builder for updating AnnouncementRead entities.
type AnnouncementReadUpdate struct {
config
hooks []Hook
mutation *AnnouncementReadMutation
}
// Where appends a list predicates to the AnnouncementReadUpdate builder.
func (_u *AnnouncementReadUpdate) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetAnnouncementID sets the "announcement_id" field.
func (_u *AnnouncementReadUpdate) SetAnnouncementID(v int64) *AnnouncementReadUpdate {
_u.mutation.SetAnnouncementID(v)
return _u
}
// SetNillableAnnouncementID sets the "announcement_id" field if the given value is not nil.
func (_u *AnnouncementReadUpdate) SetNillableAnnouncementID(v *int64) *AnnouncementReadUpdate {
if v != nil {
_u.SetAnnouncementID(*v)
}
return _u
}
// SetUserID sets the "user_id" field.
func (_u *AnnouncementReadUpdate) SetUserID(v int64) *AnnouncementReadUpdate {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *AnnouncementReadUpdate) SetNillableUserID(v *int64) *AnnouncementReadUpdate {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetReadAt sets the "read_at" field.
func (_u *AnnouncementReadUpdate) SetReadAt(v time.Time) *AnnouncementReadUpdate {
_u.mutation.SetReadAt(v)
return _u
}
// SetNillableReadAt sets the "read_at" field if the given value is not nil.
func (_u *AnnouncementReadUpdate) SetNillableReadAt(v *time.Time) *AnnouncementReadUpdate {
if v != nil {
_u.SetReadAt(*v)
}
return _u
}
// SetAnnouncement sets the "announcement" edge to the Announcement entity.
func (_u *AnnouncementReadUpdate) SetAnnouncement(v *Announcement) *AnnouncementReadUpdate {
return _u.SetAnnouncementID(v.ID)
}
// SetUser sets the "user" edge to the User entity.
func (_u *AnnouncementReadUpdate) SetUser(v *User) *AnnouncementReadUpdate {
return _u.SetUserID(v.ID)
}
// Mutation returns the AnnouncementReadMutation object of the builder.
func (_u *AnnouncementReadUpdate) Mutation() *AnnouncementReadMutation {
return _u.mutation
}
// ClearAnnouncement clears the "announcement" edge to the Announcement entity.
func (_u *AnnouncementReadUpdate) ClearAnnouncement() *AnnouncementReadUpdate {
_u.mutation.ClearAnnouncement()
return _u
}
// ClearUser clears the "user" edge to the User entity.
func (_u *AnnouncementReadUpdate) ClearUser() *AnnouncementReadUpdate {
_u.mutation.ClearUser()
return _u
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *AnnouncementReadUpdate) Save(ctx context.Context) (int, error) {
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AnnouncementReadUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *AnnouncementReadUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AnnouncementReadUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *AnnouncementReadUpdate) check() error {
if _u.mutation.AnnouncementCleared() && len(_u.mutation.AnnouncementIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "AnnouncementRead.announcement"`)
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "AnnouncementRead.user"`)
}
return nil
}
func (_u *AnnouncementReadUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(announcementread.Table, announcementread.Columns, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.ReadAt(); ok {
_spec.SetField(announcementread.FieldReadAt, field.TypeTime, value)
}
if _u.mutation.AnnouncementCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.AnnouncementTable,
Columns: []string{announcementread.AnnouncementColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AnnouncementIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.AnnouncementTable,
Columns: []string{announcementread.AnnouncementColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.UserTable,
Columns: []string{announcementread.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.UserTable,
Columns: []string{announcementread.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{announcementread.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// AnnouncementReadUpdateOne is the builder for updating a single AnnouncementRead entity.
type AnnouncementReadUpdateOne struct {
config
fields []string
hooks []Hook
mutation *AnnouncementReadMutation
}
// SetAnnouncementID sets the "announcement_id" field.
func (_u *AnnouncementReadUpdateOne) SetAnnouncementID(v int64) *AnnouncementReadUpdateOne {
_u.mutation.SetAnnouncementID(v)
return _u
}
// SetNillableAnnouncementID sets the "announcement_id" field if the given value is not nil.
func (_u *AnnouncementReadUpdateOne) SetNillableAnnouncementID(v *int64) *AnnouncementReadUpdateOne {
if v != nil {
_u.SetAnnouncementID(*v)
}
return _u
}
// SetUserID sets the "user_id" field.
func (_u *AnnouncementReadUpdateOne) SetUserID(v int64) *AnnouncementReadUpdateOne {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *AnnouncementReadUpdateOne) SetNillableUserID(v *int64) *AnnouncementReadUpdateOne {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetReadAt sets the "read_at" field.
func (_u *AnnouncementReadUpdateOne) SetReadAt(v time.Time) *AnnouncementReadUpdateOne {
_u.mutation.SetReadAt(v)
return _u
}
// SetNillableReadAt sets the "read_at" field if the given value is not nil.
func (_u *AnnouncementReadUpdateOne) SetNillableReadAt(v *time.Time) *AnnouncementReadUpdateOne {
if v != nil {
_u.SetReadAt(*v)
}
return _u
}
// SetAnnouncement sets the "announcement" edge to the Announcement entity.
func (_u *AnnouncementReadUpdateOne) SetAnnouncement(v *Announcement) *AnnouncementReadUpdateOne {
return _u.SetAnnouncementID(v.ID)
}
// SetUser sets the "user" edge to the User entity.
func (_u *AnnouncementReadUpdateOne) SetUser(v *User) *AnnouncementReadUpdateOne {
return _u.SetUserID(v.ID)
}
// Mutation returns the AnnouncementReadMutation object of the builder.
func (_u *AnnouncementReadUpdateOne) Mutation() *AnnouncementReadMutation {
return _u.mutation
}
// ClearAnnouncement clears the "announcement" edge to the Announcement entity.
func (_u *AnnouncementReadUpdateOne) ClearAnnouncement() *AnnouncementReadUpdateOne {
_u.mutation.ClearAnnouncement()
return _u
}
// ClearUser clears the "user" edge to the User entity.
func (_u *AnnouncementReadUpdateOne) ClearUser() *AnnouncementReadUpdateOne {
_u.mutation.ClearUser()
return _u
}
// Where appends a list predicates to the AnnouncementReadUpdate builder.
func (_u *AnnouncementReadUpdateOne) Where(ps ...predicate.AnnouncementRead) *AnnouncementReadUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *AnnouncementReadUpdateOne) Select(field string, fields ...string) *AnnouncementReadUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated AnnouncementRead entity.
func (_u *AnnouncementReadUpdateOne) Save(ctx context.Context) (*AnnouncementRead, error) {
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AnnouncementReadUpdateOne) SaveX(ctx context.Context) *AnnouncementRead {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *AnnouncementReadUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AnnouncementReadUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *AnnouncementReadUpdateOne) check() error {
if _u.mutation.AnnouncementCleared() && len(_u.mutation.AnnouncementIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "AnnouncementRead.announcement"`)
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "AnnouncementRead.user"`)
}
return nil
}
func (_u *AnnouncementReadUpdateOne) sqlSave(ctx context.Context) (_node *AnnouncementRead, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(announcementread.Table, announcementread.Columns, sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AnnouncementRead.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, announcementread.FieldID)
for _, f := range fields {
if !announcementread.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != announcementread.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.ReadAt(); ok {
_spec.SetField(announcementread.FieldReadAt, field.TypeTime, value)
}
if _u.mutation.AnnouncementCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.AnnouncementTable,
Columns: []string{announcementread.AnnouncementColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AnnouncementIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.AnnouncementTable,
Columns: []string{announcementread.AnnouncementColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcement.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.UserTable,
Columns: []string{announcementread.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: announcementread.UserTable,
Columns: []string{announcementread.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &AnnouncementRead{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{announcementread.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -40,6 +40,12 @@ type APIKey struct {
IPWhitelist []string `json:"ip_whitelist,omitempty"`
// Blocked IPs/CIDRs
IPBlacklist []string `json:"ip_blacklist,omitempty"`
// Quota limit in USD for this API key (0 = unlimited)
Quota float64 `json:"quota,omitempty"`
// Used quota amount in USD
QuotaUsed float64 `json:"quota_used,omitempty"`
// Expiration time for this API key (null = never expires)
ExpiresAt *time.Time `json:"expires_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the APIKeyQuery when eager-loading is set.
Edges APIKeyEdges `json:"edges"`
@@ -97,11 +103,13 @@ func (*APIKey) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist:
values[i] = new([]byte)
case apikey.FieldQuota, apikey.FieldQuotaUsed:
values[i] = new(sql.NullFloat64)
case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID:
values[i] = new(sql.NullInt64)
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
values[i] = new(sql.NullString)
case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt:
case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldExpiresAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -190,6 +198,25 @@ func (_m *APIKey) assignValues(columns []string, values []any) error {
return fmt.Errorf("unmarshal field ip_blacklist: %w", err)
}
}
case apikey.FieldQuota:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field quota", values[i])
} else if value.Valid {
_m.Quota = value.Float64
}
case apikey.FieldQuotaUsed:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field quota_used", values[i])
} else if value.Valid {
_m.QuotaUsed = value.Float64
}
case apikey.FieldExpiresAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field expires_at", values[i])
} else if value.Valid {
_m.ExpiresAt = new(time.Time)
*_m.ExpiresAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -274,6 +301,17 @@ func (_m *APIKey) String() string {
builder.WriteString(", ")
builder.WriteString("ip_blacklist=")
builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist))
builder.WriteString(", ")
builder.WriteString("quota=")
builder.WriteString(fmt.Sprintf("%v", _m.Quota))
builder.WriteString(", ")
builder.WriteString("quota_used=")
builder.WriteString(fmt.Sprintf("%v", _m.QuotaUsed))
builder.WriteString(", ")
if v := _m.ExpiresAt; v != nil {
builder.WriteString("expires_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteByte(')')
return builder.String()
}

View File

@@ -35,6 +35,12 @@ const (
FieldIPWhitelist = "ip_whitelist"
// FieldIPBlacklist holds the string denoting the ip_blacklist field in the database.
FieldIPBlacklist = "ip_blacklist"
// FieldQuota holds the string denoting the quota field in the database.
FieldQuota = "quota"
// FieldQuotaUsed holds the string denoting the quota_used field in the database.
FieldQuotaUsed = "quota_used"
// FieldExpiresAt holds the string denoting the expires_at field in the database.
FieldExpiresAt = "expires_at"
// EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user"
// EdgeGroup holds the string denoting the group edge name in mutations.
@@ -79,6 +85,9 @@ var Columns = []string{
FieldStatus,
FieldIPWhitelist,
FieldIPBlacklist,
FieldQuota,
FieldQuotaUsed,
FieldExpiresAt,
}
// ValidColumn reports if the column name is valid (part of the table columns).
@@ -113,6 +122,10 @@ var (
DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
StatusValidator func(string) error
// DefaultQuota holds the default value on creation for the "quota" field.
DefaultQuota float64
// DefaultQuotaUsed holds the default value on creation for the "quota_used" field.
DefaultQuotaUsed float64
)
// OrderOption defines the ordering options for the APIKey queries.
@@ -163,6 +176,21 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
}
// ByQuota orders the results by the quota field.
func ByQuota(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldQuota, opts...).ToFunc()
}
// ByQuotaUsed orders the results by the quota_used field.
func ByQuotaUsed(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldQuotaUsed, opts...).ToFunc()
}
// ByExpiresAt orders the results by the expires_at field.
func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
}
// ByUserField orders the results by user field.
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -95,6 +95,21 @@ func Status(v string) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldStatus, v))
}
// Quota applies equality check predicate on the "quota" field. It's identical to QuotaEQ.
func Quota(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldQuota, v))
}
// QuotaUsed applies equality check predicate on the "quota_used" field. It's identical to QuotaUsedEQ.
func QuotaUsed(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldQuotaUsed, v))
}
// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
func ExpiresAt(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldCreatedAt, v))
@@ -490,6 +505,136 @@ func IPBlacklistNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist))
}
// QuotaEQ applies the EQ predicate on the "quota" field.
func QuotaEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldQuota, v))
}
// QuotaNEQ applies the NEQ predicate on the "quota" field.
func QuotaNEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldQuota, v))
}
// QuotaIn applies the In predicate on the "quota" field.
func QuotaIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldQuota, vs...))
}
// QuotaNotIn applies the NotIn predicate on the "quota" field.
func QuotaNotIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldQuota, vs...))
}
// QuotaGT applies the GT predicate on the "quota" field.
func QuotaGT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldQuota, v))
}
// QuotaGTE applies the GTE predicate on the "quota" field.
func QuotaGTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldQuota, v))
}
// QuotaLT applies the LT predicate on the "quota" field.
func QuotaLT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldQuota, v))
}
// QuotaLTE applies the LTE predicate on the "quota" field.
func QuotaLTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldQuota, v))
}
// QuotaUsedEQ applies the EQ predicate on the "quota_used" field.
func QuotaUsedEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldQuotaUsed, v))
}
// QuotaUsedNEQ applies the NEQ predicate on the "quota_used" field.
func QuotaUsedNEQ(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldQuotaUsed, v))
}
// QuotaUsedIn applies the In predicate on the "quota_used" field.
func QuotaUsedIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldQuotaUsed, vs...))
}
// QuotaUsedNotIn applies the NotIn predicate on the "quota_used" field.
func QuotaUsedNotIn(vs ...float64) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldQuotaUsed, vs...))
}
// QuotaUsedGT applies the GT predicate on the "quota_used" field.
func QuotaUsedGT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldQuotaUsed, v))
}
// QuotaUsedGTE applies the GTE predicate on the "quota_used" field.
func QuotaUsedGTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldQuotaUsed, v))
}
// QuotaUsedLT applies the LT predicate on the "quota_used" field.
func QuotaUsedLT(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldQuotaUsed, v))
}
// QuotaUsedLTE applies the LTE predicate on the "quota_used" field.
func QuotaUsedLTE(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldQuotaUsed, v))
}
// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
func ExpiresAtEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v))
}
// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
func ExpiresAtNEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldExpiresAt, v))
}
// ExpiresAtIn applies the In predicate on the "expires_at" field.
func ExpiresAtIn(vs ...time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldExpiresAt, vs...))
}
// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
func ExpiresAtNotIn(vs ...time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldExpiresAt, vs...))
}
// ExpiresAtGT applies the GT predicate on the "expires_at" field.
func ExpiresAtGT(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldExpiresAt, v))
}
// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
func ExpiresAtGTE(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldExpiresAt, v))
}
// ExpiresAtLT applies the LT predicate on the "expires_at" field.
func ExpiresAtLT(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldExpiresAt, v))
}
// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
func ExpiresAtLTE(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldExpiresAt, v))
}
// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field.
func ExpiresAtIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldExpiresAt))
}
// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field.
func ExpiresAtNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldExpiresAt))
}
// HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.APIKey {
return predicate.APIKey(func(s *sql.Selector) {

View File

@@ -125,6 +125,48 @@ func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate {
return _c
}
// SetQuota sets the "quota" field.
func (_c *APIKeyCreate) SetQuota(v float64) *APIKeyCreate {
_c.mutation.SetQuota(v)
return _c
}
// SetNillableQuota sets the "quota" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableQuota(v *float64) *APIKeyCreate {
if v != nil {
_c.SetQuota(*v)
}
return _c
}
// SetQuotaUsed sets the "quota_used" field.
func (_c *APIKeyCreate) SetQuotaUsed(v float64) *APIKeyCreate {
_c.mutation.SetQuotaUsed(v)
return _c
}
// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableQuotaUsed(v *float64) *APIKeyCreate {
if v != nil {
_c.SetQuotaUsed(*v)
}
return _c
}
// SetExpiresAt sets the "expires_at" field.
func (_c *APIKeyCreate) SetExpiresAt(v time.Time) *APIKeyCreate {
_c.mutation.SetExpiresAt(v)
return _c
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableExpiresAt(v *time.Time) *APIKeyCreate {
if v != nil {
_c.SetExpiresAt(*v)
}
return _c
}
// SetUser sets the "user" edge to the User entity.
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate {
return _c.SetUserID(v.ID)
@@ -205,6 +247,14 @@ func (_c *APIKeyCreate) defaults() error {
v := apikey.DefaultStatus
_c.mutation.SetStatus(v)
}
if _, ok := _c.mutation.Quota(); !ok {
v := apikey.DefaultQuota
_c.mutation.SetQuota(v)
}
if _, ok := _c.mutation.QuotaUsed(); !ok {
v := apikey.DefaultQuotaUsed
_c.mutation.SetQuotaUsed(v)
}
return nil
}
@@ -243,6 +293,12 @@ func (_c *APIKeyCreate) check() error {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "APIKey.status": %w`, err)}
}
}
if _, ok := _c.mutation.Quota(); !ok {
return &ValidationError{Name: "quota", err: errors.New(`ent: missing required field "APIKey.quota"`)}
}
if _, ok := _c.mutation.QuotaUsed(); !ok {
return &ValidationError{Name: "quota_used", err: errors.New(`ent: missing required field "APIKey.quota_used"`)}
}
if len(_c.mutation.UserIDs()) == 0 {
return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "APIKey.user"`)}
}
@@ -305,6 +361,18 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) {
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
_node.IPBlacklist = value
}
if value, ok := _c.mutation.Quota(); ok {
_spec.SetField(apikey.FieldQuota, field.TypeFloat64, value)
_node.Quota = value
}
if value, ok := _c.mutation.QuotaUsed(); ok {
_spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value)
_node.QuotaUsed = value
}
if value, ok := _c.mutation.ExpiresAt(); ok {
_spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value)
_node.ExpiresAt = &value
}
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
@@ -539,6 +607,60 @@ func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert {
return u
}
// SetQuota sets the "quota" field.
func (u *APIKeyUpsert) SetQuota(v float64) *APIKeyUpsert {
u.Set(apikey.FieldQuota, v)
return u
}
// UpdateQuota sets the "quota" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateQuota() *APIKeyUpsert {
u.SetExcluded(apikey.FieldQuota)
return u
}
// AddQuota adds v to the "quota" field.
func (u *APIKeyUpsert) AddQuota(v float64) *APIKeyUpsert {
u.Add(apikey.FieldQuota, v)
return u
}
// SetQuotaUsed sets the "quota_used" field.
func (u *APIKeyUpsert) SetQuotaUsed(v float64) *APIKeyUpsert {
u.Set(apikey.FieldQuotaUsed, v)
return u
}
// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateQuotaUsed() *APIKeyUpsert {
u.SetExcluded(apikey.FieldQuotaUsed)
return u
}
// AddQuotaUsed adds v to the "quota_used" field.
func (u *APIKeyUpsert) AddQuotaUsed(v float64) *APIKeyUpsert {
u.Add(apikey.FieldQuotaUsed, v)
return u
}
// SetExpiresAt sets the "expires_at" field.
func (u *APIKeyUpsert) SetExpiresAt(v time.Time) *APIKeyUpsert {
u.Set(apikey.FieldExpiresAt, v)
return u
}
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateExpiresAt() *APIKeyUpsert {
u.SetExcluded(apikey.FieldExpiresAt)
return u
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (u *APIKeyUpsert) ClearExpiresAt() *APIKeyUpsert {
u.SetNull(apikey.FieldExpiresAt)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -738,6 +860,69 @@ func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne {
})
}
// SetQuota sets the "quota" field.
func (u *APIKeyUpsertOne) SetQuota(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetQuota(v)
})
}
// AddQuota adds v to the "quota" field.
func (u *APIKeyUpsertOne) AddQuota(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.AddQuota(v)
})
}
// UpdateQuota sets the "quota" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateQuota() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateQuota()
})
}
// SetQuotaUsed sets the "quota_used" field.
func (u *APIKeyUpsertOne) SetQuotaUsed(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetQuotaUsed(v)
})
}
// AddQuotaUsed adds v to the "quota_used" field.
func (u *APIKeyUpsertOne) AddQuotaUsed(v float64) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.AddQuotaUsed(v)
})
}
// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateQuotaUsed() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateQuotaUsed()
})
}
// SetExpiresAt sets the "expires_at" field.
func (u *APIKeyUpsertOne) SetExpiresAt(v time.Time) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetExpiresAt(v)
})
}
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateExpiresAt() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateExpiresAt()
})
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (u *APIKeyUpsertOne) ClearExpiresAt() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearExpiresAt()
})
}
// Exec executes the query.
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1103,6 +1288,69 @@ func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk {
})
}
// SetQuota sets the "quota" field.
func (u *APIKeyUpsertBulk) SetQuota(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetQuota(v)
})
}
// AddQuota adds v to the "quota" field.
func (u *APIKeyUpsertBulk) AddQuota(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.AddQuota(v)
})
}
// UpdateQuota sets the "quota" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateQuota() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateQuota()
})
}
// SetQuotaUsed sets the "quota_used" field.
func (u *APIKeyUpsertBulk) SetQuotaUsed(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetQuotaUsed(v)
})
}
// AddQuotaUsed adds v to the "quota_used" field.
func (u *APIKeyUpsertBulk) AddQuotaUsed(v float64) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.AddQuotaUsed(v)
})
}
// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateQuotaUsed() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateQuotaUsed()
})
}
// SetExpiresAt sets the "expires_at" field.
func (u *APIKeyUpsertBulk) SetExpiresAt(v time.Time) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetExpiresAt(v)
})
}
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateExpiresAt() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateExpiresAt()
})
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (u *APIKeyUpsertBulk) ClearExpiresAt() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearExpiresAt()
})
}
// Exec executes the query.
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -170,6 +170,68 @@ func (_u *APIKeyUpdate) ClearIPBlacklist() *APIKeyUpdate {
return _u
}
// SetQuota sets the "quota" field.
func (_u *APIKeyUpdate) SetQuota(v float64) *APIKeyUpdate {
_u.mutation.ResetQuota()
_u.mutation.SetQuota(v)
return _u
}
// SetNillableQuota sets the "quota" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableQuota(v *float64) *APIKeyUpdate {
if v != nil {
_u.SetQuota(*v)
}
return _u
}
// AddQuota adds value to the "quota" field.
func (_u *APIKeyUpdate) AddQuota(v float64) *APIKeyUpdate {
_u.mutation.AddQuota(v)
return _u
}
// SetQuotaUsed sets the "quota_used" field.
func (_u *APIKeyUpdate) SetQuotaUsed(v float64) *APIKeyUpdate {
_u.mutation.ResetQuotaUsed()
_u.mutation.SetQuotaUsed(v)
return _u
}
// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableQuotaUsed(v *float64) *APIKeyUpdate {
if v != nil {
_u.SetQuotaUsed(*v)
}
return _u
}
// AddQuotaUsed adds value to the "quota_used" field.
func (_u *APIKeyUpdate) AddQuotaUsed(v float64) *APIKeyUpdate {
_u.mutation.AddQuotaUsed(v)
return _u
}
// SetExpiresAt sets the "expires_at" field.
func (_u *APIKeyUpdate) SetExpiresAt(v time.Time) *APIKeyUpdate {
_u.mutation.SetExpiresAt(v)
return _u
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableExpiresAt(v *time.Time) *APIKeyUpdate {
if v != nil {
_u.SetExpiresAt(*v)
}
return _u
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (_u *APIKeyUpdate) ClearExpiresAt() *APIKeyUpdate {
_u.mutation.ClearExpiresAt()
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate {
return _u.SetUserID(v.ID)
@@ -350,6 +412,24 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.IPBlacklistCleared() {
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
}
if value, ok := _u.mutation.Quota(); ok {
_spec.SetField(apikey.FieldQuota, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedQuota(); ok {
_spec.AddField(apikey.FieldQuota, field.TypeFloat64, value)
}
if value, ok := _u.mutation.QuotaUsed(); ok {
_spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedQuotaUsed(); ok {
_spec.AddField(apikey.FieldQuotaUsed, field.TypeFloat64, value)
}
if value, ok := _u.mutation.ExpiresAt(); ok {
_spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value)
}
if _u.mutation.ExpiresAtCleared() {
_spec.ClearField(apikey.FieldExpiresAt, field.TypeTime)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
@@ -611,6 +691,68 @@ func (_u *APIKeyUpdateOne) ClearIPBlacklist() *APIKeyUpdateOne {
return _u
}
// SetQuota sets the "quota" field.
func (_u *APIKeyUpdateOne) SetQuota(v float64) *APIKeyUpdateOne {
_u.mutation.ResetQuota()
_u.mutation.SetQuota(v)
return _u
}
// SetNillableQuota sets the "quota" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableQuota(v *float64) *APIKeyUpdateOne {
if v != nil {
_u.SetQuota(*v)
}
return _u
}
// AddQuota adds value to the "quota" field.
func (_u *APIKeyUpdateOne) AddQuota(v float64) *APIKeyUpdateOne {
_u.mutation.AddQuota(v)
return _u
}
// SetQuotaUsed sets the "quota_used" field.
func (_u *APIKeyUpdateOne) SetQuotaUsed(v float64) *APIKeyUpdateOne {
_u.mutation.ResetQuotaUsed()
_u.mutation.SetQuotaUsed(v)
return _u
}
// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableQuotaUsed(v *float64) *APIKeyUpdateOne {
if v != nil {
_u.SetQuotaUsed(*v)
}
return _u
}
// AddQuotaUsed adds value to the "quota_used" field.
func (_u *APIKeyUpdateOne) AddQuotaUsed(v float64) *APIKeyUpdateOne {
_u.mutation.AddQuotaUsed(v)
return _u
}
// SetExpiresAt sets the "expires_at" field.
func (_u *APIKeyUpdateOne) SetExpiresAt(v time.Time) *APIKeyUpdateOne {
_u.mutation.SetExpiresAt(v)
return _u
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableExpiresAt(v *time.Time) *APIKeyUpdateOne {
if v != nil {
_u.SetExpiresAt(*v)
}
return _u
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (_u *APIKeyUpdateOne) ClearExpiresAt() *APIKeyUpdateOne {
_u.mutation.ClearExpiresAt()
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne {
return _u.SetUserID(v.ID)
@@ -821,6 +963,24 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro
if _u.mutation.IPBlacklistCleared() {
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
}
if value, ok := _u.mutation.Quota(); ok {
_spec.SetField(apikey.FieldQuota, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedQuota(); ok {
_spec.AddField(apikey.FieldQuota, field.TypeFloat64, value)
}
if value, ok := _u.mutation.QuotaUsed(); ok {
_spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedQuotaUsed(); ok {
_spec.AddField(apikey.FieldQuotaUsed, field.TypeFloat64, value)
}
if value, ok := _u.mutation.ExpiresAt(); ok {
_spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value)
}
if _u.mutation.ExpiresAtCleared() {
_spec.ClearField(apikey.FieldExpiresAt, field.TypeTime)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,

View File

@@ -17,6 +17,8 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode"
@@ -46,6 +48,10 @@ type Client struct {
Account *AccountClient
// AccountGroup is the client for interacting with the AccountGroup builders.
AccountGroup *AccountGroupClient
// Announcement is the client for interacting with the Announcement builders.
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// PromoCode is the client for interacting with the PromoCode builders.
@@ -86,6 +92,8 @@ func (c *Client) init() {
c.APIKey = NewAPIKeyClient(c.config)
c.Account = NewAccountClient(c.config)
c.AccountGroup = NewAccountGroupClient(c.config)
c.Announcement = NewAnnouncementClient(c.config)
c.AnnouncementRead = NewAnnouncementReadClient(c.config)
c.Group = NewGroupClient(c.config)
c.PromoCode = NewPromoCodeClient(c.config)
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
@@ -194,6 +202,8 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
APIKey: NewAPIKeyClient(cfg),
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
Announcement: NewAnnouncementClient(cfg),
AnnouncementRead: NewAnnouncementReadClient(cfg),
Group: NewGroupClient(cfg),
PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
@@ -229,6 +239,8 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
APIKey: NewAPIKeyClient(cfg),
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
Announcement: NewAnnouncementClient(cfg),
AnnouncementRead: NewAnnouncementReadClient(cfg),
Group: NewGroupClient(cfg),
PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
@@ -271,10 +283,10 @@ func (c *Client) Close() error {
// In order to add hooks to a specific client, call: `client.Node.Use(...)`.
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting,
c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
} {
n.Use(hooks...)
}
@@ -284,10 +296,10 @@ func (c *Client) Use(hooks ...Hook) {
// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`.
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting,
c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
} {
n.Intercept(interceptors...)
}
@@ -302,6 +314,10 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.Account.mutate(ctx, m)
case *AccountGroupMutation:
return c.AccountGroup.mutate(ctx, m)
case *AnnouncementMutation:
return c.Announcement.mutate(ctx, m)
case *AnnouncementReadMutation:
return c.AnnouncementRead.mutate(ctx, m)
case *GroupMutation:
return c.Group.mutate(ctx, m)
case *PromoCodeMutation:
@@ -831,6 +847,320 @@ func (c *AccountGroupClient) mutate(ctx context.Context, m *AccountGroupMutation
}
}
// AnnouncementClient is a client for the Announcement schema.
type AnnouncementClient struct {
config
}
// NewAnnouncementClient returns a client for the Announcement from the given config.
func NewAnnouncementClient(c config) *AnnouncementClient {
return &AnnouncementClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `announcement.Hooks(f(g(h())))`.
func (c *AnnouncementClient) Use(hooks ...Hook) {
c.hooks.Announcement = append(c.hooks.Announcement, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `announcement.Intercept(f(g(h())))`.
func (c *AnnouncementClient) Intercept(interceptors ...Interceptor) {
c.inters.Announcement = append(c.inters.Announcement, interceptors...)
}
// Create returns a builder for creating a Announcement entity.
func (c *AnnouncementClient) Create() *AnnouncementCreate {
mutation := newAnnouncementMutation(c.config, OpCreate)
return &AnnouncementCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of Announcement entities.
func (c *AnnouncementClient) CreateBulk(builders ...*AnnouncementCreate) *AnnouncementCreateBulk {
return &AnnouncementCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *AnnouncementClient) MapCreateBulk(slice any, setFunc func(*AnnouncementCreate, int)) *AnnouncementCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &AnnouncementCreateBulk{err: fmt.Errorf("calling to AnnouncementClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*AnnouncementCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &AnnouncementCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for Announcement.
func (c *AnnouncementClient) Update() *AnnouncementUpdate {
mutation := newAnnouncementMutation(c.config, OpUpdate)
return &AnnouncementUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *AnnouncementClient) UpdateOne(_m *Announcement) *AnnouncementUpdateOne {
mutation := newAnnouncementMutation(c.config, OpUpdateOne, withAnnouncement(_m))
return &AnnouncementUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *AnnouncementClient) UpdateOneID(id int64) *AnnouncementUpdateOne {
mutation := newAnnouncementMutation(c.config, OpUpdateOne, withAnnouncementID(id))
return &AnnouncementUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for Announcement.
func (c *AnnouncementClient) Delete() *AnnouncementDelete {
mutation := newAnnouncementMutation(c.config, OpDelete)
return &AnnouncementDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *AnnouncementClient) DeleteOne(_m *Announcement) *AnnouncementDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *AnnouncementClient) DeleteOneID(id int64) *AnnouncementDeleteOne {
builder := c.Delete().Where(announcement.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &AnnouncementDeleteOne{builder}
}
// Query returns a query builder for Announcement.
func (c *AnnouncementClient) Query() *AnnouncementQuery {
return &AnnouncementQuery{
config: c.config,
ctx: &QueryContext{Type: TypeAnnouncement},
inters: c.Interceptors(),
}
}
// Get returns a Announcement entity by its id.
func (c *AnnouncementClient) Get(ctx context.Context, id int64) (*Announcement, error) {
return c.Query().Where(announcement.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *AnnouncementClient) GetX(ctx context.Context, id int64) *Announcement {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryReads queries the reads edge of a Announcement.
func (c *AnnouncementClient) QueryReads(_m *Announcement) *AnnouncementReadQuery {
query := (&AnnouncementReadClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(announcement.Table, announcement.FieldID, id),
sqlgraph.To(announcementread.Table, announcementread.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, announcement.ReadsTable, announcement.ReadsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *AnnouncementClient) Hooks() []Hook {
return c.hooks.Announcement
}
// Interceptors returns the client interceptors.
func (c *AnnouncementClient) Interceptors() []Interceptor {
return c.inters.Announcement
}
func (c *AnnouncementClient) mutate(ctx context.Context, m *AnnouncementMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&AnnouncementCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&AnnouncementUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&AnnouncementUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&AnnouncementDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown Announcement mutation op: %q", m.Op())
}
}
// AnnouncementReadClient is a client for the AnnouncementRead schema.
type AnnouncementReadClient struct {
config
}
// NewAnnouncementReadClient returns a client for the AnnouncementRead from the given config.
func NewAnnouncementReadClient(c config) *AnnouncementReadClient {
return &AnnouncementReadClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `announcementread.Hooks(f(g(h())))`.
func (c *AnnouncementReadClient) Use(hooks ...Hook) {
c.hooks.AnnouncementRead = append(c.hooks.AnnouncementRead, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `announcementread.Intercept(f(g(h())))`.
func (c *AnnouncementReadClient) Intercept(interceptors ...Interceptor) {
c.inters.AnnouncementRead = append(c.inters.AnnouncementRead, interceptors...)
}
// Create returns a builder for creating a AnnouncementRead entity.
func (c *AnnouncementReadClient) Create() *AnnouncementReadCreate {
mutation := newAnnouncementReadMutation(c.config, OpCreate)
return &AnnouncementReadCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of AnnouncementRead entities.
func (c *AnnouncementReadClient) CreateBulk(builders ...*AnnouncementReadCreate) *AnnouncementReadCreateBulk {
return &AnnouncementReadCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *AnnouncementReadClient) MapCreateBulk(slice any, setFunc func(*AnnouncementReadCreate, int)) *AnnouncementReadCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &AnnouncementReadCreateBulk{err: fmt.Errorf("calling to AnnouncementReadClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*AnnouncementReadCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &AnnouncementReadCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for AnnouncementRead.
func (c *AnnouncementReadClient) Update() *AnnouncementReadUpdate {
mutation := newAnnouncementReadMutation(c.config, OpUpdate)
return &AnnouncementReadUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *AnnouncementReadClient) UpdateOne(_m *AnnouncementRead) *AnnouncementReadUpdateOne {
mutation := newAnnouncementReadMutation(c.config, OpUpdateOne, withAnnouncementRead(_m))
return &AnnouncementReadUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *AnnouncementReadClient) UpdateOneID(id int64) *AnnouncementReadUpdateOne {
mutation := newAnnouncementReadMutation(c.config, OpUpdateOne, withAnnouncementReadID(id))
return &AnnouncementReadUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for AnnouncementRead.
func (c *AnnouncementReadClient) Delete() *AnnouncementReadDelete {
mutation := newAnnouncementReadMutation(c.config, OpDelete)
return &AnnouncementReadDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *AnnouncementReadClient) DeleteOne(_m *AnnouncementRead) *AnnouncementReadDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *AnnouncementReadClient) DeleteOneID(id int64) *AnnouncementReadDeleteOne {
builder := c.Delete().Where(announcementread.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &AnnouncementReadDeleteOne{builder}
}
// Query returns a query builder for AnnouncementRead.
func (c *AnnouncementReadClient) Query() *AnnouncementReadQuery {
return &AnnouncementReadQuery{
config: c.config,
ctx: &QueryContext{Type: TypeAnnouncementRead},
inters: c.Interceptors(),
}
}
// Get returns a AnnouncementRead entity by its id.
func (c *AnnouncementReadClient) Get(ctx context.Context, id int64) (*AnnouncementRead, error) {
return c.Query().Where(announcementread.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *AnnouncementReadClient) GetX(ctx context.Context, id int64) *AnnouncementRead {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryAnnouncement queries the announcement edge of a AnnouncementRead.
func (c *AnnouncementReadClient) QueryAnnouncement(_m *AnnouncementRead) *AnnouncementQuery {
query := (&AnnouncementClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(announcementread.Table, announcementread.FieldID, id),
sqlgraph.To(announcement.Table, announcement.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, announcementread.AnnouncementTable, announcementread.AnnouncementColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryUser queries the user edge of a AnnouncementRead.
func (c *AnnouncementReadClient) QueryUser(_m *AnnouncementRead) *UserQuery {
query := (&UserClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(announcementread.Table, announcementread.FieldID, id),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, announcementread.UserTable, announcementread.UserColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *AnnouncementReadClient) Hooks() []Hook {
return c.hooks.AnnouncementRead
}
// Interceptors returns the client interceptors.
func (c *AnnouncementReadClient) Interceptors() []Interceptor {
return c.inters.AnnouncementRead
}
func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementReadMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&AnnouncementReadCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&AnnouncementReadUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&AnnouncementReadUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&AnnouncementReadDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown AnnouncementRead mutation op: %q", m.Op())
}
}
// GroupClient is a client for the Group schema.
type GroupClient struct {
config
@@ -2375,6 +2705,22 @@ func (c *UserClient) QueryAssignedSubscriptions(_m *User) *UserSubscriptionQuery
return query
}
// QueryAnnouncementReads queries the announcement_reads edge of a User.
func (c *UserClient) QueryAnnouncementReads(_m *User) *AnnouncementReadQuery {
query := (&AnnouncementReadClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, id),
sqlgraph.To(announcementread.Table, announcementread.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.AnnouncementReadsTable, user.AnnouncementReadsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryAllowedGroups queries the allowed_groups edge of a User.
func (c *UserClient) QueryAllowedGroups(_m *User) *GroupQuery {
query := (&GroupClient{config: c.config}).Query()
@@ -3116,14 +3462,16 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
// hooks and interceptors per client, for fast access.
type (
hooks struct {
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode,
PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User,
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode,
PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User,
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
UserSubscription []ent.Interceptor
}
)

View File

@@ -14,6 +14,8 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode"
@@ -91,6 +93,8 @@ func checkColumn(t, c string) error {
apikey.Table: apikey.ValidColumn,
account.Table: account.ValidColumn,
accountgroup.Table: accountgroup.ValidColumn,
announcement.Table: announcement.ValidColumn,
announcementread.Table: announcementread.ValidColumn,
group.Table: group.ValidColumn,
promocode.Table: promocode.ValidColumn,
promocodeusage.Table: promocodeusage.ValidColumn,

View File

@@ -56,10 +56,16 @@ type Group struct {
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// 无效请求兜底使用的分组 ID
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
// 模型路由配置:模型模式 -> 优先账号ID列表
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
// 是否启用模型路由配置
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
// 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
// 支持的模型系列claude, gemini_text, gemini_image
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -166,13 +172,13 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case group.FieldModelRouting:
case group.FieldModelRouting, group.FieldSupportedModelScopes:
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled:
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID:
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest:
values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
values[i] = new(sql.NullString)
@@ -322,6 +328,13 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.FallbackGroupID = new(int64)
*_m.FallbackGroupID = value.Int64
}
case group.FieldFallbackGroupIDOnInvalidRequest:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field fallback_group_id_on_invalid_request", values[i])
} else if value.Valid {
_m.FallbackGroupIDOnInvalidRequest = new(int64)
*_m.FallbackGroupIDOnInvalidRequest = value.Int64
}
case group.FieldModelRouting:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field model_routing", values[i])
@@ -336,6 +349,20 @@ func (_m *Group) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.ModelRoutingEnabled = value.Bool
}
case group.FieldMcpXMLInject:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field mcp_xml_inject", values[i])
} else if value.Valid {
_m.McpXMLInject = value.Bool
}
case group.FieldSupportedModelScopes:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field supported_model_scopes", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.SupportedModelScopes); err != nil {
return fmt.Errorf("unmarshal field supported_model_scopes: %w", err)
}
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -487,11 +514,22 @@ func (_m *Group) String() string {
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.FallbackGroupIDOnInvalidRequest; v != nil {
builder.WriteString("fallback_group_id_on_invalid_request=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("model_routing=")
builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting))
builder.WriteString(", ")
builder.WriteString("model_routing_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled))
builder.WriteString(", ")
builder.WriteString("mcp_xml_inject=")
builder.WriteString(fmt.Sprintf("%v", _m.McpXMLInject))
builder.WriteString(", ")
builder.WriteString("supported_model_scopes=")
builder.WriteString(fmt.Sprintf("%v", _m.SupportedModelScopes))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -53,10 +53,16 @@ const (
FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
FieldFallbackGroupID = "fallback_group_id"
// FieldFallbackGroupIDOnInvalidRequest holds the string denoting the fallback_group_id_on_invalid_request field in the database.
FieldFallbackGroupIDOnInvalidRequest = "fallback_group_id_on_invalid_request"
// FieldModelRouting holds the string denoting the model_routing field in the database.
FieldModelRouting = "model_routing"
// FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database.
FieldModelRoutingEnabled = "model_routing_enabled"
// FieldMcpXMLInject holds the string denoting the mcp_xml_inject field in the database.
FieldMcpXMLInject = "mcp_xml_inject"
// FieldSupportedModelScopes holds the string denoting the supported_model_scopes field in the database.
FieldSupportedModelScopes = "supported_model_scopes"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -151,8 +157,11 @@ var Columns = []string{
FieldImagePrice4k,
FieldClaudeCodeOnly,
FieldFallbackGroupID,
FieldFallbackGroupIDOnInvalidRequest,
FieldModelRouting,
FieldModelRoutingEnabled,
FieldMcpXMLInject,
FieldSupportedModelScopes,
}
var (
@@ -212,6 +221,10 @@ var (
DefaultClaudeCodeOnly bool
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
DefaultModelRoutingEnabled bool
// DefaultMcpXMLInject holds the default value on creation for the "mcp_xml_inject" field.
DefaultMcpXMLInject bool
// DefaultSupportedModelScopes holds the default value on creation for the "supported_model_scopes" field.
DefaultSupportedModelScopes []string
)
// OrderOption defines the ordering options for the Group queries.
@@ -317,11 +330,21 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
}
// ByFallbackGroupIDOnInvalidRequest orders the results by the fallback_group_id_on_invalid_request field.
func ByFallbackGroupIDOnInvalidRequest(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFallbackGroupIDOnInvalidRequest, opts...).ToFunc()
}
// ByModelRoutingEnabled orders the results by the model_routing_enabled field.
func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc()
}
// ByMcpXMLInject orders the results by the mcp_xml_inject field.
func ByMcpXMLInject(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMcpXMLInject, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -150,11 +150,21 @@ func FallbackGroupID(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
}
// FallbackGroupIDOnInvalidRequest applies equality check predicate on the "fallback_group_id_on_invalid_request" field. It's identical to FallbackGroupIDOnInvalidRequestEQ.
func FallbackGroupIDOnInvalidRequest(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v))
}
// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ.
func ModelRoutingEnabled(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
}
// McpXMLInject applies equality check predicate on the "mcp_xml_inject" field. It's identical to McpXMLInjectEQ.
func McpXMLInject(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -1070,6 +1080,56 @@ func FallbackGroupIDNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
}
// FallbackGroupIDOnInvalidRequestEQ applies the EQ predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestEQ(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v))
}
// FallbackGroupIDOnInvalidRequestNEQ applies the NEQ predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestNEQ(v int64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldFallbackGroupIDOnInvalidRequest, v))
}
// FallbackGroupIDOnInvalidRequestIn applies the In predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestIn(vs ...int64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldFallbackGroupIDOnInvalidRequest, vs...))
}
// FallbackGroupIDOnInvalidRequestNotIn applies the NotIn predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestNotIn(vs ...int64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldFallbackGroupIDOnInvalidRequest, vs...))
}
// FallbackGroupIDOnInvalidRequestGT applies the GT predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestGT(v int64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldFallbackGroupIDOnInvalidRequest, v))
}
// FallbackGroupIDOnInvalidRequestGTE applies the GTE predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestGTE(v int64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldFallbackGroupIDOnInvalidRequest, v))
}
// FallbackGroupIDOnInvalidRequestLT applies the LT predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestLT(v int64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldFallbackGroupIDOnInvalidRequest, v))
}
// FallbackGroupIDOnInvalidRequestLTE applies the LTE predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestLTE(v int64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldFallbackGroupIDOnInvalidRequest, v))
}
// FallbackGroupIDOnInvalidRequestIsNil applies the IsNil predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldFallbackGroupIDOnInvalidRequest))
}
// FallbackGroupIDOnInvalidRequestNotNil applies the NotNil predicate on the "fallback_group_id_on_invalid_request" field.
func FallbackGroupIDOnInvalidRequestNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupIDOnInvalidRequest))
}
// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field.
func ModelRoutingIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldModelRouting))
@@ -1090,6 +1150,16 @@ func ModelRoutingEnabledNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v))
}
// McpXMLInjectEQ applies the EQ predicate on the "mcp_xml_inject" field.
func McpXMLInjectEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v))
}
// McpXMLInjectNEQ applies the NEQ predicate on the "mcp_xml_inject" field.
func McpXMLInjectNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldMcpXMLInject, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {

View File

@@ -286,6 +286,20 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
return _c
}
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (_c *GroupCreate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupCreate {
_c.mutation.SetFallbackGroupIDOnInvalidRequest(v)
return _c
}
// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil.
func (_c *GroupCreate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupCreate {
if v != nil {
_c.SetFallbackGroupIDOnInvalidRequest(*v)
}
return _c
}
// SetModelRouting sets the "model_routing" field.
func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate {
_c.mutation.SetModelRouting(v)
@@ -306,6 +320,26 @@ func (_c *GroupCreate) SetNillableModelRoutingEnabled(v *bool) *GroupCreate {
return _c
}
// SetMcpXMLInject sets the "mcp_xml_inject" field.
func (_c *GroupCreate) SetMcpXMLInject(v bool) *GroupCreate {
_c.mutation.SetMcpXMLInject(v)
return _c
}
// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil.
func (_c *GroupCreate) SetNillableMcpXMLInject(v *bool) *GroupCreate {
if v != nil {
_c.SetMcpXMLInject(*v)
}
return _c
}
// SetSupportedModelScopes sets the "supported_model_scopes" field.
func (_c *GroupCreate) SetSupportedModelScopes(v []string) *GroupCreate {
_c.mutation.SetSupportedModelScopes(v)
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -479,6 +513,14 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultModelRoutingEnabled
_c.mutation.SetModelRoutingEnabled(v)
}
if _, ok := _c.mutation.McpXMLInject(); !ok {
v := group.DefaultMcpXMLInject
_c.mutation.SetMcpXMLInject(v)
}
if _, ok := _c.mutation.SupportedModelScopes(); !ok {
v := group.DefaultSupportedModelScopes
_c.mutation.SetSupportedModelScopes(v)
}
return nil
}
@@ -537,6 +579,12 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)}
}
if _, ok := _c.mutation.McpXMLInject(); !ok {
return &ValidationError{Name: "mcp_xml_inject", err: errors.New(`ent: missing required field "Group.mcp_xml_inject"`)}
}
if _, ok := _c.mutation.SupportedModelScopes(); !ok {
return &ValidationError{Name: "supported_model_scopes", err: errors.New(`ent: missing required field "Group.supported_model_scopes"`)}
}
return nil
}
@@ -640,6 +688,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
_node.FallbackGroupID = &value
}
if value, ok := _c.mutation.FallbackGroupIDOnInvalidRequest(); ok {
_spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
_node.FallbackGroupIDOnInvalidRequest = &value
}
if value, ok := _c.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
_node.ModelRouting = value
@@ -648,6 +700,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
_node.ModelRoutingEnabled = value
}
if value, ok := _c.mutation.McpXMLInject(); ok {
_spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value)
_node.McpXMLInject = value
}
if value, ok := _c.mutation.SupportedModelScopes(); ok {
_spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value)
_node.SupportedModelScopes = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1128,6 +1188,30 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
return u
}
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsert) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert {
u.Set(group.FieldFallbackGroupIDOnInvalidRequest, v)
return u
}
// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create.
func (u *GroupUpsert) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsert {
u.SetExcluded(group.FieldFallbackGroupIDOnInvalidRequest)
return u
}
// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsert) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert {
u.Add(group.FieldFallbackGroupIDOnInvalidRequest, v)
return u
}
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsert) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsert {
u.SetNull(group.FieldFallbackGroupIDOnInvalidRequest)
return u
}
// SetModelRouting sets the "model_routing" field.
func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert {
u.Set(group.FieldModelRouting, v)
@@ -1158,6 +1242,30 @@ func (u *GroupUpsert) UpdateModelRoutingEnabled() *GroupUpsert {
return u
}
// SetMcpXMLInject sets the "mcp_xml_inject" field.
func (u *GroupUpsert) SetMcpXMLInject(v bool) *GroupUpsert {
u.Set(group.FieldMcpXMLInject, v)
return u
}
// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create.
func (u *GroupUpsert) UpdateMcpXMLInject() *GroupUpsert {
u.SetExcluded(group.FieldMcpXMLInject)
return u
}
// SetSupportedModelScopes sets the "supported_model_scopes" field.
func (u *GroupUpsert) SetSupportedModelScopes(v []string) *GroupUpsert {
u.Set(group.FieldSupportedModelScopes, v)
return u
}
// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSupportedModelScopes() *GroupUpsert {
u.SetExcluded(group.FieldSupportedModelScopes)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1581,6 +1689,34 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
})
}
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsertOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetFallbackGroupIDOnInvalidRequest(v)
})
}
// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsertOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddFallbackGroupIDOnInvalidRequest(v)
})
}
// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateFallbackGroupIDOnInvalidRequest()
})
}
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsertOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearFallbackGroupIDOnInvalidRequest()
})
}
// SetModelRouting sets the "model_routing" field.
func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
@@ -1616,6 +1752,34 @@ func (u *GroupUpsertOne) UpdateModelRoutingEnabled() *GroupUpsertOne {
})
}
// SetMcpXMLInject sets the "mcp_xml_inject" field.
func (u *GroupUpsertOne) SetMcpXMLInject(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetMcpXMLInject(v)
})
}
// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateMcpXMLInject() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateMcpXMLInject()
})
}
// SetSupportedModelScopes sets the "supported_model_scopes" field.
func (u *GroupUpsertOne) SetSupportedModelScopes(v []string) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSupportedModelScopes(v)
})
}
// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSupportedModelScopes() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSupportedModelScopes()
})
}
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -2205,6 +2369,34 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
})
}
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsertBulk) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetFallbackGroupIDOnInvalidRequest(v)
})
}
// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsertBulk) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddFallbackGroupIDOnInvalidRequest(v)
})
}
// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateFallbackGroupIDOnInvalidRequest()
})
}
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
func (u *GroupUpsertBulk) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearFallbackGroupIDOnInvalidRequest()
})
}
// SetModelRouting sets the "model_routing" field.
func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
@@ -2240,6 +2432,34 @@ func (u *GroupUpsertBulk) UpdateModelRoutingEnabled() *GroupUpsertBulk {
})
}
// SetMcpXMLInject sets the "mcp_xml_inject" field.
func (u *GroupUpsertBulk) SetMcpXMLInject(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetMcpXMLInject(v)
})
}
// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateMcpXMLInject() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateMcpXMLInject()
})
}
// SetSupportedModelScopes sets the "supported_model_scopes" field.
func (u *GroupUpsertBulk) SetSupportedModelScopes(v []string) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSupportedModelScopes(v)
})
}
// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSupportedModelScopes() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSupportedModelScopes()
})
}
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -10,6 +10,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/apikey"
@@ -395,6 +396,33 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
return _u
}
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (_u *GroupUpdate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate {
_u.mutation.ResetFallbackGroupIDOnInvalidRequest()
_u.mutation.SetFallbackGroupIDOnInvalidRequest(v)
return _u
}
// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdate {
if v != nil {
_u.SetFallbackGroupIDOnInvalidRequest(*v)
}
return _u
}
// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field.
func (_u *GroupUpdate) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate {
_u.mutation.AddFallbackGroupIDOnInvalidRequest(v)
return _u
}
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
func (_u *GroupUpdate) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdate {
_u.mutation.ClearFallbackGroupIDOnInvalidRequest()
return _u
}
// SetModelRouting sets the "model_routing" field.
func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate {
_u.mutation.SetModelRouting(v)
@@ -421,6 +449,32 @@ func (_u *GroupUpdate) SetNillableModelRoutingEnabled(v *bool) *GroupUpdate {
return _u
}
// SetMcpXMLInject sets the "mcp_xml_inject" field.
func (_u *GroupUpdate) SetMcpXMLInject(v bool) *GroupUpdate {
_u.mutation.SetMcpXMLInject(v)
return _u
}
// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableMcpXMLInject(v *bool) *GroupUpdate {
if v != nil {
_u.SetMcpXMLInject(*v)
}
return _u
}
// SetSupportedModelScopes sets the "supported_model_scopes" field.
func (_u *GroupUpdate) SetSupportedModelScopes(v []string) *GroupUpdate {
_u.mutation.SetSupportedModelScopes(v)
return _u
}
// AppendSupportedModelScopes appends value to the "supported_model_scopes" field.
func (_u *GroupUpdate) AppendSupportedModelScopes(v []string) *GroupUpdate {
_u.mutation.AppendSupportedModelScopes(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -829,6 +883,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok {
_spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok {
_spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
}
if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() {
_spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64)
}
if value, ok := _u.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
}
@@ -838,6 +901,17 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.McpXMLInject(); ok {
_spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value)
}
if value, ok := _u.mutation.SupportedModelScopes(); ok {
_spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedSupportedModelScopes(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, group.FieldSupportedModelScopes, value)
})
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1513,6 +1587,33 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
return _u
}
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (_u *GroupUpdateOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne {
_u.mutation.ResetFallbackGroupIDOnInvalidRequest()
_u.mutation.SetFallbackGroupIDOnInvalidRequest(v)
return _u
}
// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdateOne {
if v != nil {
_u.SetFallbackGroupIDOnInvalidRequest(*v)
}
return _u
}
// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field.
func (_u *GroupUpdateOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne {
_u.mutation.AddFallbackGroupIDOnInvalidRequest(v)
return _u
}
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
func (_u *GroupUpdateOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdateOne {
_u.mutation.ClearFallbackGroupIDOnInvalidRequest()
return _u
}
// SetModelRouting sets the "model_routing" field.
func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne {
_u.mutation.SetModelRouting(v)
@@ -1539,6 +1640,32 @@ func (_u *GroupUpdateOne) SetNillableModelRoutingEnabled(v *bool) *GroupUpdateOn
return _u
}
// SetMcpXMLInject sets the "mcp_xml_inject" field.
func (_u *GroupUpdateOne) SetMcpXMLInject(v bool) *GroupUpdateOne {
_u.mutation.SetMcpXMLInject(v)
return _u
}
// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableMcpXMLInject(v *bool) *GroupUpdateOne {
if v != nil {
_u.SetMcpXMLInject(*v)
}
return _u
}
// SetSupportedModelScopes sets the "supported_model_scopes" field.
func (_u *GroupUpdateOne) SetSupportedModelScopes(v []string) *GroupUpdateOne {
_u.mutation.SetSupportedModelScopes(v)
return _u
}
// AppendSupportedModelScopes appends value to the "supported_model_scopes" field.
func (_u *GroupUpdateOne) AppendSupportedModelScopes(v []string) *GroupUpdateOne {
_u.mutation.AppendSupportedModelScopes(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1977,6 +2104,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok {
_spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok {
_spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
}
if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() {
_spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64)
}
if value, ok := _u.mutation.ModelRouting(); ok {
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
}
@@ -1986,6 +2122,17 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
_spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.McpXMLInject(); ok {
_spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value)
}
if value, ok := _u.mutation.SupportedModelScopes(); ok {
_spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedSupportedModelScopes(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, group.FieldSupportedModelScopes, value)
})
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,

View File

@@ -45,6 +45,30 @@ func (f AccountGroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AccountGroupMutation", m)
}
// The AnnouncementFunc type is an adapter to allow the use of ordinary
// function as Announcement mutator.
type AnnouncementFunc func(context.Context, *ent.AnnouncementMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f AnnouncementFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.AnnouncementMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementMutation", m)
}
// The AnnouncementReadFunc type is an adapter to allow the use of ordinary
// function as AnnouncementRead mutator.
type AnnouncementReadFunc func(context.Context, *ent.AnnouncementReadMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.AnnouncementReadMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m)
}
// The GroupFunc type is an adapter to allow the use of ordinary
// function as Group mutator.
type GroupFunc func(context.Context, *ent.GroupMutation) (ent.Value, error)

View File

@@ -10,6 +10,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
@@ -164,6 +166,60 @@ func (f TraverseAccountGroup) Traverse(ctx context.Context, q ent.Query) error {
return fmt.Errorf("unexpected query type %T. expect *ent.AccountGroupQuery", q)
}
// The AnnouncementFunc type is an adapter to allow the use of ordinary function as a Querier.
type AnnouncementFunc func(context.Context, *ent.AnnouncementQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f AnnouncementFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.AnnouncementQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementQuery", q)
}
// The TraverseAnnouncement type is an adapter to allow the use of ordinary function as Traverser.
type TraverseAnnouncement func(context.Context, *ent.AnnouncementQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseAnnouncement) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseAnnouncement) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.AnnouncementQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementQuery", q)
}
// The AnnouncementReadFunc type is an adapter to allow the use of ordinary function as a Querier.
type AnnouncementReadFunc func(context.Context, *ent.AnnouncementReadQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f AnnouncementReadFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.AnnouncementReadQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
}
// The TraverseAnnouncementRead type is an adapter to allow the use of ordinary function as Traverser.
type TraverseAnnouncementRead func(context.Context, *ent.AnnouncementReadQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseAnnouncementRead) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.AnnouncementReadQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
}
// The GroupFunc type is an adapter to allow the use of ordinary function as a Querier.
type GroupFunc func(context.Context, *ent.GroupQuery) (ent.Value, error)
@@ -524,6 +580,10 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.AccountQuery, predicate.Account, account.OrderOption]{typ: ent.TypeAccount, tq: q}, nil
case *ent.AccountGroupQuery:
return &query[*ent.AccountGroupQuery, predicate.AccountGroup, accountgroup.OrderOption]{typ: ent.TypeAccountGroup, tq: q}, nil
case *ent.AnnouncementQuery:
return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil
case *ent.AnnouncementReadQuery:
return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil
case *ent.GroupQuery:
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
case *ent.PromoCodeQuery:

View File

@@ -20,6 +20,9 @@ var (
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true},
{Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true},
{Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "quota_used", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "expires_at", Type: field.TypeTime, Nullable: true},
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
{Name: "user_id", Type: field.TypeInt64},
}
@@ -31,13 +34,13 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "api_keys_groups_api_keys",
Columns: []*schema.Column{APIKeysColumns[9]},
Columns: []*schema.Column{APIKeysColumns[12]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "api_keys_users_api_keys",
Columns: []*schema.Column{APIKeysColumns[10]},
Columns: []*schema.Column{APIKeysColumns[13]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
@@ -46,12 +49,12 @@ var (
{
Name: "apikey_user_id",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[10]},
Columns: []*schema.Column{APIKeysColumns[13]},
},
{
Name: "apikey_group_id",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[9]},
Columns: []*schema.Column{APIKeysColumns[12]},
},
{
Name: "apikey_status",
@@ -63,6 +66,16 @@ var (
Unique: false,
Columns: []*schema.Column{APIKeysColumns[3]},
},
{
Name: "apikey_quota_quota_used",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[9], APIKeysColumns[10]},
},
{
Name: "apikey_expires_at",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[11]},
},
},
}
// AccountsColumns holds the columns for the "accounts" table.
@@ -204,6 +217,98 @@ var (
},
},
}
// AnnouncementsColumns holds the columns for the "announcements" table.
AnnouncementsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "title", Type: field.TypeString, Size: 200},
{Name: "content", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
{Name: "status", Type: field.TypeString, Size: 20, Default: "draft"},
{Name: "targeting", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "starts_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "ends_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "created_by", Type: field.TypeInt64, Nullable: true},
{Name: "updated_by", Type: field.TypeInt64, Nullable: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
}
// AnnouncementsTable holds the schema information for the "announcements" table.
AnnouncementsTable = &schema.Table{
Name: "announcements",
Columns: AnnouncementsColumns,
PrimaryKey: []*schema.Column{AnnouncementsColumns[0]},
Indexes: []*schema.Index{
{
Name: "announcement_status",
Unique: false,
Columns: []*schema.Column{AnnouncementsColumns[3]},
},
{
Name: "announcement_created_at",
Unique: false,
Columns: []*schema.Column{AnnouncementsColumns[9]},
},
{
Name: "announcement_starts_at",
Unique: false,
Columns: []*schema.Column{AnnouncementsColumns[5]},
},
{
Name: "announcement_ends_at",
Unique: false,
Columns: []*schema.Column{AnnouncementsColumns[6]},
},
},
}
// AnnouncementReadsColumns holds the columns for the "announcement_reads" table.
AnnouncementReadsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "read_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "announcement_id", Type: field.TypeInt64},
{Name: "user_id", Type: field.TypeInt64},
}
// AnnouncementReadsTable holds the schema information for the "announcement_reads" table.
AnnouncementReadsTable = &schema.Table{
Name: "announcement_reads",
Columns: AnnouncementReadsColumns,
PrimaryKey: []*schema.Column{AnnouncementReadsColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "announcement_reads_announcements_reads",
Columns: []*schema.Column{AnnouncementReadsColumns[3]},
RefColumns: []*schema.Column{AnnouncementsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "announcement_reads_users_announcement_reads",
Columns: []*schema.Column{AnnouncementReadsColumns[4]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
},
Indexes: []*schema.Index{
{
Name: "announcementread_announcement_id",
Unique: false,
Columns: []*schema.Column{AnnouncementReadsColumns[3]},
},
{
Name: "announcementread_user_id",
Unique: false,
Columns: []*schema.Column{AnnouncementReadsColumns[4]},
},
{
Name: "announcementread_read_at",
Unique: false,
Columns: []*schema.Column{AnnouncementReadsColumns[1]},
},
{
Name: "announcementread_announcement_id_user_id",
Unique: true,
Columns: []*schema.Column{AnnouncementReadsColumns[3], AnnouncementReadsColumns[4]},
},
},
}
// GroupsColumns holds the columns for the "groups" table.
GroupsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -226,8 +331,11 @@ var (
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
{Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
{Name: "mcp_xml_inject", Type: field.TypeBool, Default: true},
{Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{
@@ -840,6 +948,8 @@ var (
APIKeysTable,
AccountsTable,
AccountGroupsTable,
AnnouncementsTable,
AnnouncementReadsTable,
GroupsTable,
PromoCodesTable,
PromoCodeUsagesTable,
@@ -871,6 +981,14 @@ func init() {
AccountGroupsTable.Annotation = &entsql.Annotation{
Table: "account_groups",
}
AnnouncementsTable.Annotation = &entsql.Annotation{
Table: "announcements",
}
AnnouncementReadsTable.ForeignKeys[0].RefTable = AnnouncementsTable
AnnouncementReadsTable.ForeignKeys[1].RefTable = UsersTable
AnnouncementReadsTable.Annotation = &entsql.Annotation{
Table: "announcement_reads",
}
GroupsTable.Annotation = &entsql.Annotation{
Table: "groups",
}

File diff suppressed because it is too large Load Diff

View File

@@ -15,6 +15,12 @@ type Account func(*sql.Selector)
// AccountGroup is the predicate function for accountgroup builders.
type AccountGroup func(*sql.Selector)
// Announcement is the predicate function for announcement builders.
type Announcement func(*sql.Selector)
// AnnouncementRead is the predicate function for announcementread builders.
type AnnouncementRead func(*sql.Selector)
// Group is the predicate function for group builders.
type Group func(*sql.Selector)

View File

@@ -7,6 +7,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/account"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocode"
@@ -89,6 +91,14 @@ func init() {
apikey.DefaultStatus = apikeyDescStatus.Default.(string)
// apikey.StatusValidator is a validator for the "status" field. It is called by the builders before save.
apikey.StatusValidator = apikeyDescStatus.Validators[0].(func(string) error)
// apikeyDescQuota is the schema descriptor for quota field.
apikeyDescQuota := apikeyFields[7].Descriptor()
// apikey.DefaultQuota holds the default value on creation for the quota field.
apikey.DefaultQuota = apikeyDescQuota.Default.(float64)
// apikeyDescQuotaUsed is the schema descriptor for quota_used field.
apikeyDescQuotaUsed := apikeyFields[8].Descriptor()
// apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field.
apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64)
accountMixin := schema.Account{}.Mixin()
accountMixinHooks1 := accountMixin[1].Hooks()
account.Hooks[0] = accountMixinHooks1[0]
@@ -210,6 +220,56 @@ func init() {
accountgroupDescCreatedAt := accountgroupFields[3].Descriptor()
// accountgroup.DefaultCreatedAt holds the default value on creation for the created_at field.
accountgroup.DefaultCreatedAt = accountgroupDescCreatedAt.Default.(func() time.Time)
announcementFields := schema.Announcement{}.Fields()
_ = announcementFields
// announcementDescTitle is the schema descriptor for title field.
announcementDescTitle := announcementFields[0].Descriptor()
// announcement.TitleValidator is a validator for the "title" field. It is called by the builders before save.
announcement.TitleValidator = func() func(string) error {
validators := announcementDescTitle.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(title string) error {
for _, fn := range fns {
if err := fn(title); err != nil {
return err
}
}
return nil
}
}()
// announcementDescContent is the schema descriptor for content field.
announcementDescContent := announcementFields[1].Descriptor()
// announcement.ContentValidator is a validator for the "content" field. It is called by the builders before save.
announcement.ContentValidator = announcementDescContent.Validators[0].(func(string) error)
// announcementDescStatus is the schema descriptor for status field.
announcementDescStatus := announcementFields[2].Descriptor()
// announcement.DefaultStatus holds the default value on creation for the status field.
announcement.DefaultStatus = announcementDescStatus.Default.(string)
// announcement.StatusValidator is a validator for the "status" field. It is called by the builders before save.
announcement.StatusValidator = announcementDescStatus.Validators[0].(func(string) error)
// announcementDescCreatedAt is the schema descriptor for created_at field.
announcementDescCreatedAt := announcementFields[8].Descriptor()
// announcement.DefaultCreatedAt holds the default value on creation for the created_at field.
announcement.DefaultCreatedAt = announcementDescCreatedAt.Default.(func() time.Time)
// announcementDescUpdatedAt is the schema descriptor for updated_at field.
announcementDescUpdatedAt := announcementFields[9].Descriptor()
// announcement.DefaultUpdatedAt holds the default value on creation for the updated_at field.
announcement.DefaultUpdatedAt = announcementDescUpdatedAt.Default.(func() time.Time)
// announcement.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
announcement.UpdateDefaultUpdatedAt = announcementDescUpdatedAt.UpdateDefault.(func() time.Time)
announcementreadFields := schema.AnnouncementRead{}.Fields()
_ = announcementreadFields
// announcementreadDescReadAt is the schema descriptor for read_at field.
announcementreadDescReadAt := announcementreadFields[2].Descriptor()
// announcementread.DefaultReadAt holds the default value on creation for the read_at field.
announcementread.DefaultReadAt = announcementreadDescReadAt.Default.(func() time.Time)
// announcementreadDescCreatedAt is the schema descriptor for created_at field.
announcementreadDescCreatedAt := announcementreadFields[3].Descriptor()
// announcementread.DefaultCreatedAt holds the default value on creation for the created_at field.
announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time)
groupMixin := schema.Group{}.Mixin()
groupMixinHooks1 := groupMixin[1].Hooks()
group.Hooks[0] = groupMixinHooks1[0]
@@ -282,9 +342,17 @@ func init() {
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
groupDescModelRoutingEnabled := groupFields[17].Descriptor()
groupDescModelRoutingEnabled := groupFields[18].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
groupDescMcpXMLInject := groupFields[19].Descriptor()
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
groupDescSupportedModelScopes := groupFields[20].Descriptor()
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields
// promocodeDescCode is the schema descriptor for code field.

View File

@@ -4,7 +4,7 @@ package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -111,7 +111,7 @@ func (Account) Fields() []ent.Field {
// status: 账户状态,如 "active", "error", "disabled"
field.String("status").
MaxLen(20).
Default(service.StatusActive),
Default(domain.StatusActive),
// error_message: 错误信息,记录账户异常时的详细信息
field.String("error_message").

View File

@@ -0,0 +1,90 @@
package schema
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// Announcement holds the schema definition for the Announcement entity.
//
// 删除策略:硬删除(已读记录通过外键级联删除)
type Announcement struct {
ent.Schema
}
func (Announcement) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "announcements"},
}
}
func (Announcement) Fields() []ent.Field {
return []ent.Field{
field.String("title").
MaxLen(200).
NotEmpty().
Comment("公告标题"),
field.String("content").
SchemaType(map[string]string{dialect.Postgres: "text"}).
NotEmpty().
Comment("公告内容(支持 Markdown"),
field.String("status").
MaxLen(20).
Default(domain.AnnouncementStatusDraft).
Comment("状态: draft, active, archived"),
field.JSON("targeting", domain.AnnouncementTargeting{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("展示条件JSON 规则)"),
field.Time("starts_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}).
Comment("开始展示时间(为空表示立即生效)"),
field.Time("ends_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}).
Comment("结束展示时间(为空表示永久生效)"),
field.Int64("created_by").
Optional().
Nillable().
Comment("创建人用户ID管理员"),
field.Int64("updated_by").
Optional().
Nillable().
Comment("更新人用户ID管理员"),
field.Time("created_at").
Immutable().
Default(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("updated_at").
Default(time.Now).
UpdateDefault(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (Announcement) Edges() []ent.Edge {
return []ent.Edge{
edge.To("reads", AnnouncementRead.Type),
}
}
func (Announcement) Indexes() []ent.Index {
return []ent.Index{
index.Fields("status"),
index.Fields("created_at"),
index.Fields("starts_at"),
index.Fields("ends_at"),
}
}

View File

@@ -0,0 +1,65 @@
package schema
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// AnnouncementRead holds the schema definition for the AnnouncementRead entity.
//
// 记录用户对公告的已读状态(首次已读时间)。
type AnnouncementRead struct {
ent.Schema
}
func (AnnouncementRead) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "announcement_reads"},
}
}
func (AnnouncementRead) Fields() []ent.Field {
return []ent.Field{
field.Int64("announcement_id"),
field.Int64("user_id"),
field.Time("read_at").
Default(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}).
Comment("用户首次已读时间"),
field.Time("created_at").
Immutable().
Default(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (AnnouncementRead) Edges() []ent.Edge {
return []ent.Edge{
edge.From("announcement", Announcement.Type).
Ref("reads").
Field("announcement_id").
Unique().
Required(),
edge.From("user", User.Type).
Ref("announcement_reads").
Field("user_id").
Unique().
Required(),
}
}
func (AnnouncementRead) Indexes() []ent.Index {
return []ent.Index{
index.Fields("announcement_id"),
index.Fields("user_id"),
index.Fields("read_at"),
index.Fields("announcement_id", "user_id").Unique(),
}
}

View File

@@ -2,9 +2,10 @@ package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
@@ -45,13 +46,30 @@ func (APIKey) Fields() []ent.Field {
Nillable(),
field.String("status").
MaxLen(20).
Default(service.StatusActive),
Default(domain.StatusActive),
field.JSON("ip_whitelist", []string{}).
Optional().
Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"),
field.JSON("ip_blacklist", []string{}).
Optional().
Comment("Blocked IPs/CIDRs"),
// ========== Quota fields ==========
// Quota limit in USD (0 = unlimited)
field.Float("quota").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0).
Comment("Quota limit in USD for this API key (0 = unlimited)"),
// Used quota amount
field.Float("quota_used").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0).
Comment("Used quota amount in USD"),
// Expiration time (nil = never expires)
field.Time("expires_at").
Optional().
Nillable().
Comment("Expiration time for this API key (null = never expires)"),
}
}
@@ -77,5 +95,8 @@ func (APIKey) Indexes() []ent.Index {
index.Fields("group_id"),
index.Fields("status"),
index.Fields("deleted_at"),
// Index for quota queries
index.Fields("quota", "quota_used"),
index.Fields("expires_at"),
}
}

View File

@@ -2,7 +2,7 @@ package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -49,15 +49,15 @@ func (Group) Fields() []ent.Field {
Default(false),
field.String("status").
MaxLen(20).
Default(service.StatusActive),
Default(domain.StatusActive),
// Subscription-related fields (added by migration 003)
field.String("platform").
MaxLen(50).
Default(service.PlatformAnthropic),
Default(domain.PlatformAnthropic),
field.String("subscription_type").
MaxLen(20).
Default(service.SubscriptionTypeStandard),
Default(domain.SubscriptionTypeStandard),
field.Float("daily_limit_usd").
Optional().
Nillable().
@@ -95,6 +95,10 @@ func (Group) Fields() []ent.Field {
Optional().
Nillable().
Comment("非 Claude Code 请求降级使用的分组 ID"),
field.Int64("fallback_group_id_on_invalid_request").
Optional().
Nillable().
Comment("无效请求兜底使用的分组 ID"),
// 模型路由配置 (added by migration 040)
field.JSON("model_routing", map[string][]int64{}).
@@ -106,6 +110,17 @@ func (Group) Fields() []ent.Field {
field.Bool("model_routing_enabled").
Default(false).
Comment("是否启用模型路由配置"),
// MCP XML 协议注入开关 (added by migration 042)
field.Bool("mcp_xml_inject").
Default(true).
Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"),
// 支持的模型系列 (added by migration 046)
field.JSON("supported_model_scopes", []string{}).
Default([]string{"claude", "gemini_text", "gemini_image"}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("支持的模型系列claude, gemini_text, gemini_image"),
}
}

View File

@@ -3,7 +3,7 @@ package schema
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -49,7 +49,7 @@ func (PromoCode) Fields() []ent.Field {
Comment("已使用次数"),
field.String("status").
MaxLen(20).
Default(service.PromoCodeStatusActive).
Default(domain.PromoCodeStatusActive).
Comment("状态: active, disabled"),
field.Time("expires_at").
Optional().

View File

@@ -3,7 +3,7 @@ package schema
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -41,13 +41,13 @@ func (RedeemCode) Fields() []ent.Field {
Unique(),
field.String("type").
MaxLen(20).
Default(service.RedeemTypeBalance),
Default(domain.RedeemTypeBalance),
field.Float("value").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0),
field.String("status").
MaxLen(20).
Default(service.StatusUnused),
Default(domain.StatusUnused),
field.Int64("used_by").
Optional().
Nillable(),

View File

@@ -2,7 +2,7 @@ package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -43,7 +43,7 @@ func (User) Fields() []ent.Field {
NotEmpty(),
field.String("role").
MaxLen(20).
Default(service.RoleUser),
Default(domain.RoleUser),
field.Float("balance").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0),
@@ -51,7 +51,7 @@ func (User) Fields() []ent.Field {
Default(5),
field.String("status").
MaxLen(20).
Default(service.StatusActive),
Default(domain.StatusActive),
// Optional profile fields (added later; default '' in DB migration)
field.String("username").
@@ -81,6 +81,7 @@ func (User) Edges() []ent.Edge {
edge.To("redeem_codes", RedeemCode.Type),
edge.To("subscriptions", UserSubscription.Type),
edge.To("assigned_subscriptions", UserSubscription.Type),
edge.To("announcement_reads", AnnouncementRead.Type),
edge.To("allowed_groups", Group.Type).
Through("user_allowed_groups", UserAllowedGroup.Type),
edge.To("usage_logs", UsageLog.Type),

View File

@@ -4,7 +4,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/domain"
"entgo.io/ent"
"entgo.io/ent/dialect"
@@ -44,7 +44,7 @@ func (UserSubscription) Fields() []ent.Field {
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.String("status").
MaxLen(20).
Default(service.SubscriptionStatusActive),
Default(domain.SubscriptionStatusActive),
field.Time("daily_window_start").
Optional().

View File

@@ -20,6 +20,10 @@ type Tx struct {
Account *AccountClient
// AccountGroup is the client for interacting with the AccountGroup builders.
AccountGroup *AccountGroupClient
// Announcement is the client for interacting with the Announcement builders.
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// PromoCode is the client for interacting with the PromoCode builders.
@@ -180,6 +184,8 @@ func (tx *Tx) init() {
tx.APIKey = NewAPIKeyClient(tx.config)
tx.Account = NewAccountClient(tx.config)
tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Announcement = NewAnnouncementClient(tx.config)
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)

View File

@@ -61,6 +61,8 @@ type UserEdges struct {
Subscriptions []*UserSubscription `json:"subscriptions,omitempty"`
// AssignedSubscriptions holds the value of the assigned_subscriptions edge.
AssignedSubscriptions []*UserSubscription `json:"assigned_subscriptions,omitempty"`
// AnnouncementReads holds the value of the announcement_reads edge.
AnnouncementReads []*AnnouncementRead `json:"announcement_reads,omitempty"`
// AllowedGroups holds the value of the allowed_groups edge.
AllowedGroups []*Group `json:"allowed_groups,omitempty"`
// UsageLogs holds the value of the usage_logs edge.
@@ -73,7 +75,7 @@ type UserEdges struct {
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [9]bool
loadedTypes [10]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@@ -112,10 +114,19 @@ func (e UserEdges) AssignedSubscriptionsOrErr() ([]*UserSubscription, error) {
return nil, &NotLoadedError{edge: "assigned_subscriptions"}
}
// AnnouncementReadsOrErr returns the AnnouncementReads value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AnnouncementReadsOrErr() ([]*AnnouncementRead, error) {
if e.loadedTypes[4] {
return e.AnnouncementReads, nil
}
return nil, &NotLoadedError{edge: "announcement_reads"}
}
// AllowedGroupsOrErr returns the AllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AllowedGroupsOrErr() ([]*Group, error) {
if e.loadedTypes[4] {
if e.loadedTypes[5] {
return e.AllowedGroups, nil
}
return nil, &NotLoadedError{edge: "allowed_groups"}
@@ -124,7 +135,7 @@ func (e UserEdges) AllowedGroupsOrErr() ([]*Group, error) {
// UsageLogsOrErr returns the UsageLogs value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) {
if e.loadedTypes[5] {
if e.loadedTypes[6] {
return e.UsageLogs, nil
}
return nil, &NotLoadedError{edge: "usage_logs"}
@@ -133,7 +144,7 @@ func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) {
// AttributeValuesOrErr returns the AttributeValues value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AttributeValuesOrErr() ([]*UserAttributeValue, error) {
if e.loadedTypes[6] {
if e.loadedTypes[7] {
return e.AttributeValues, nil
}
return nil, &NotLoadedError{edge: "attribute_values"}
@@ -142,7 +153,7 @@ func (e UserEdges) AttributeValuesOrErr() ([]*UserAttributeValue, error) {
// PromoCodeUsagesOrErr returns the PromoCodeUsages value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) PromoCodeUsagesOrErr() ([]*PromoCodeUsage, error) {
if e.loadedTypes[7] {
if e.loadedTypes[8] {
return e.PromoCodeUsages, nil
}
return nil, &NotLoadedError{edge: "promo_code_usages"}
@@ -151,7 +162,7 @@ func (e UserEdges) PromoCodeUsagesOrErr() ([]*PromoCodeUsage, error) {
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[8] {
if e.loadedTypes[9] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@@ -313,6 +324,11 @@ func (_m *User) QueryAssignedSubscriptions() *UserSubscriptionQuery {
return NewUserClient(_m.config).QueryAssignedSubscriptions(_m)
}
// QueryAnnouncementReads queries the "announcement_reads" edge of the User entity.
func (_m *User) QueryAnnouncementReads() *AnnouncementReadQuery {
return NewUserClient(_m.config).QueryAnnouncementReads(_m)
}
// QueryAllowedGroups queries the "allowed_groups" edge of the User entity.
func (_m *User) QueryAllowedGroups() *GroupQuery {
return NewUserClient(_m.config).QueryAllowedGroups(_m)

View File

@@ -51,6 +51,8 @@ const (
EdgeSubscriptions = "subscriptions"
// EdgeAssignedSubscriptions holds the string denoting the assigned_subscriptions edge name in mutations.
EdgeAssignedSubscriptions = "assigned_subscriptions"
// EdgeAnnouncementReads holds the string denoting the announcement_reads edge name in mutations.
EdgeAnnouncementReads = "announcement_reads"
// EdgeAllowedGroups holds the string denoting the allowed_groups edge name in mutations.
EdgeAllowedGroups = "allowed_groups"
// EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations.
@@ -91,6 +93,13 @@ const (
AssignedSubscriptionsInverseTable = "user_subscriptions"
// AssignedSubscriptionsColumn is the table column denoting the assigned_subscriptions relation/edge.
AssignedSubscriptionsColumn = "assigned_by"
// AnnouncementReadsTable is the table that holds the announcement_reads relation/edge.
AnnouncementReadsTable = "announcement_reads"
// AnnouncementReadsInverseTable is the table name for the AnnouncementRead entity.
// It exists in this package in order to avoid circular dependency with the "announcementread" package.
AnnouncementReadsInverseTable = "announcement_reads"
// AnnouncementReadsColumn is the table column denoting the announcement_reads relation/edge.
AnnouncementReadsColumn = "user_id"
// AllowedGroupsTable is the table that holds the allowed_groups relation/edge. The primary key declared below.
AllowedGroupsTable = "user_allowed_groups"
// AllowedGroupsInverseTable is the table name for the Group entity.
@@ -335,6 +344,20 @@ func ByAssignedSubscriptions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOp
}
}
// ByAnnouncementReadsCount orders the results by announcement_reads count.
func ByAnnouncementReadsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newAnnouncementReadsStep(), opts...)
}
}
// ByAnnouncementReads orders the results by announcement_reads terms.
func ByAnnouncementReads(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAnnouncementReadsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByAllowedGroupsCount orders the results by allowed_groups count.
func ByAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -432,6 +455,13 @@ func newAssignedSubscriptionsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, AssignedSubscriptionsTable, AssignedSubscriptionsColumn),
)
}
func newAnnouncementReadsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AnnouncementReadsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AnnouncementReadsTable, AnnouncementReadsColumn),
)
}
func newAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),

View File

@@ -952,6 +952,29 @@ func HasAssignedSubscriptionsWith(preds ...predicate.UserSubscription) predicate
})
}
// HasAnnouncementReads applies the HasEdge predicate on the "announcement_reads" edge.
func HasAnnouncementReads() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AnnouncementReadsTable, AnnouncementReadsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasAnnouncementReadsWith applies the HasEdge predicate on the "announcement_reads" edge with a given conditions (other predicates).
func HasAnnouncementReadsWith(preds ...predicate.AnnouncementRead) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newAnnouncementReadsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasAllowedGroups applies the HasEdge predicate on the "allowed_groups" edge.
func HasAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@@ -11,6 +11,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -269,6 +270,21 @@ func (_c *UserCreate) AddAssignedSubscriptions(v ...*UserSubscription) *UserCrea
return _c.AddAssignedSubscriptionIDs(ids...)
}
// AddAnnouncementReadIDs adds the "announcement_reads" edge to the AnnouncementRead entity by IDs.
func (_c *UserCreate) AddAnnouncementReadIDs(ids ...int64) *UserCreate {
_c.mutation.AddAnnouncementReadIDs(ids...)
return _c
}
// AddAnnouncementReads adds the "announcement_reads" edges to the AnnouncementRead entity.
func (_c *UserCreate) AddAnnouncementReads(v ...*AnnouncementRead) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddAnnouncementReadIDs(ids...)
}
// AddAllowedGroupIDs adds the "allowed_groups" edge to the Group entity by IDs.
func (_c *UserCreate) AddAllowedGroupIDs(ids ...int64) *UserCreate {
_c.mutation.AddAllowedGroupIDs(ids...)
@@ -618,6 +634,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.AnnouncementReadsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.AllowedGroupsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,

View File

@@ -13,6 +13,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
@@ -36,6 +37,7 @@ type UserQuery struct {
withRedeemCodes *RedeemCodeQuery
withSubscriptions *UserSubscriptionQuery
withAssignedSubscriptions *UserSubscriptionQuery
withAnnouncementReads *AnnouncementReadQuery
withAllowedGroups *GroupQuery
withUsageLogs *UsageLogQuery
withAttributeValues *UserAttributeValueQuery
@@ -166,6 +168,28 @@ func (_q *UserQuery) QueryAssignedSubscriptions() *UserSubscriptionQuery {
return query
}
// QueryAnnouncementReads chains the current query on the "announcement_reads" edge.
func (_q *UserQuery) QueryAnnouncementReads() *AnnouncementReadQuery {
query := (&AnnouncementReadClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(announcementread.Table, announcementread.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.AnnouncementReadsTable, user.AnnouncementReadsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryAllowedGroups chains the current query on the "allowed_groups" edge.
func (_q *UserQuery) QueryAllowedGroups() *GroupQuery {
query := (&GroupClient{config: _q.config}).Query()
@@ -472,6 +496,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withRedeemCodes: _q.withRedeemCodes.Clone(),
withSubscriptions: _q.withSubscriptions.Clone(),
withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(),
withAnnouncementReads: _q.withAnnouncementReads.Clone(),
withAllowedGroups: _q.withAllowedGroups.Clone(),
withUsageLogs: _q.withUsageLogs.Clone(),
withAttributeValues: _q.withAttributeValues.Clone(),
@@ -527,6 +552,17 @@ func (_q *UserQuery) WithAssignedSubscriptions(opts ...func(*UserSubscriptionQue
return _q
}
// WithAnnouncementReads tells the query-builder to eager-load the nodes that are connected to
// the "announcement_reads" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithAnnouncementReads(opts ...func(*AnnouncementReadQuery)) *UserQuery {
query := (&AnnouncementReadClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAnnouncementReads = query
return _q
}
// WithAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithAllowedGroups(opts ...func(*GroupQuery)) *UserQuery {
@@ -660,11 +696,12 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
loadedTypes = [9]bool{
loadedTypes = [10]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
_q.withAssignedSubscriptions != nil,
_q.withAnnouncementReads != nil,
_q.withAllowedGroups != nil,
_q.withUsageLogs != nil,
_q.withAttributeValues != nil,
@@ -723,6 +760,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
if query := _q.withAnnouncementReads; query != nil {
if err := _q.loadAnnouncementReads(ctx, query, nodes,
func(n *User) { n.Edges.AnnouncementReads = []*AnnouncementRead{} },
func(n *User, e *AnnouncementRead) { n.Edges.AnnouncementReads = append(n.Edges.AnnouncementReads, e) }); err != nil {
return nil, err
}
}
if query := _q.withAllowedGroups; query != nil {
if err := _q.loadAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.AllowedGroups = []*Group{} },
@@ -887,6 +931,36 @@ func (_q *UserQuery) loadAssignedSubscriptions(ctx context.Context, query *UserS
}
return nil
}
func (_q *UserQuery) loadAnnouncementReads(ctx context.Context, query *AnnouncementReadQuery, nodes []*User, init func(*User), assign func(*User, *AnnouncementRead)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(announcementread.FieldUserID)
}
query.Where(predicate.AnnouncementRead(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.AnnouncementReadsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadAllowedGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error {
edgeIDs := make([]driver.Value, len(nodes))
byID := make(map[int64]*User)

View File

@@ -11,6 +11,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
@@ -301,6 +302,21 @@ func (_u *UserUpdate) AddAssignedSubscriptions(v ...*UserSubscription) *UserUpda
return _u.AddAssignedSubscriptionIDs(ids...)
}
// AddAnnouncementReadIDs adds the "announcement_reads" edge to the AnnouncementRead entity by IDs.
func (_u *UserUpdate) AddAnnouncementReadIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAnnouncementReadIDs(ids...)
return _u
}
// AddAnnouncementReads adds the "announcement_reads" edges to the AnnouncementRead entity.
func (_u *UserUpdate) AddAnnouncementReads(v ...*AnnouncementRead) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAnnouncementReadIDs(ids...)
}
// AddAllowedGroupIDs adds the "allowed_groups" edge to the Group entity by IDs.
func (_u *UserUpdate) AddAllowedGroupIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAllowedGroupIDs(ids...)
@@ -450,6 +466,27 @@ func (_u *UserUpdate) RemoveAssignedSubscriptions(v ...*UserSubscription) *UserU
return _u.RemoveAssignedSubscriptionIDs(ids...)
}
// ClearAnnouncementReads clears all "announcement_reads" edges to the AnnouncementRead entity.
func (_u *UserUpdate) ClearAnnouncementReads() *UserUpdate {
_u.mutation.ClearAnnouncementReads()
return _u
}
// RemoveAnnouncementReadIDs removes the "announcement_reads" edge to AnnouncementRead entities by IDs.
func (_u *UserUpdate) RemoveAnnouncementReadIDs(ids ...int64) *UserUpdate {
_u.mutation.RemoveAnnouncementReadIDs(ids...)
return _u
}
// RemoveAnnouncementReads removes "announcement_reads" edges to AnnouncementRead entities.
func (_u *UserUpdate) RemoveAnnouncementReads(v ...*AnnouncementRead) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAnnouncementReadIDs(ids...)
}
// ClearAllowedGroups clears all "allowed_groups" edges to the Group entity.
func (_u *UserUpdate) ClearAllowedGroups() *UserUpdate {
_u.mutation.ClearAllowedGroups()
@@ -852,6 +889,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AnnouncementReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAnnouncementReadsIDs(); len(nodes) > 0 && !_u.mutation.AnnouncementReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AnnouncementReadsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AllowedGroupsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,
@@ -1330,6 +1412,21 @@ func (_u *UserUpdateOne) AddAssignedSubscriptions(v ...*UserSubscription) *UserU
return _u.AddAssignedSubscriptionIDs(ids...)
}
// AddAnnouncementReadIDs adds the "announcement_reads" edge to the AnnouncementRead entity by IDs.
func (_u *UserUpdateOne) AddAnnouncementReadIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAnnouncementReadIDs(ids...)
return _u
}
// AddAnnouncementReads adds the "announcement_reads" edges to the AnnouncementRead entity.
func (_u *UserUpdateOne) AddAnnouncementReads(v ...*AnnouncementRead) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAnnouncementReadIDs(ids...)
}
// AddAllowedGroupIDs adds the "allowed_groups" edge to the Group entity by IDs.
func (_u *UserUpdateOne) AddAllowedGroupIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAllowedGroupIDs(ids...)
@@ -1479,6 +1576,27 @@ func (_u *UserUpdateOne) RemoveAssignedSubscriptions(v ...*UserSubscription) *Us
return _u.RemoveAssignedSubscriptionIDs(ids...)
}
// ClearAnnouncementReads clears all "announcement_reads" edges to the AnnouncementRead entity.
func (_u *UserUpdateOne) ClearAnnouncementReads() *UserUpdateOne {
_u.mutation.ClearAnnouncementReads()
return _u
}
// RemoveAnnouncementReadIDs removes the "announcement_reads" edge to AnnouncementRead entities by IDs.
func (_u *UserUpdateOne) RemoveAnnouncementReadIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemoveAnnouncementReadIDs(ids...)
return _u
}
// RemoveAnnouncementReads removes "announcement_reads" edges to AnnouncementRead entities.
func (_u *UserUpdateOne) RemoveAnnouncementReads(v ...*AnnouncementRead) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAnnouncementReadIDs(ids...)
}
// ClearAllowedGroups clears all "allowed_groups" edges to the Group entity.
func (_u *UserUpdateOne) ClearAllowedGroups() *UserUpdateOne {
_u.mutation.ClearAllowedGroups()
@@ -1911,6 +2029,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AnnouncementReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAnnouncementReadsIDs(); len(nodes) > 0 && !_u.mutation.AnnouncementReadsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AnnouncementReadsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AnnouncementReadsTable,
Columns: []string{user.AnnouncementReadsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(announcementread.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AllowedGroupsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,

View File

@@ -1,9 +1,11 @@
module github.com/Wei-Shaw/sub2api
go 1.25.5
go 1.25.6
require (
entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/dgraph-io/ristretto v0.2.0
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/google/uuid v1.6.0
@@ -11,7 +13,10 @@ require (
github.com/gorilla/websocket v1.5.3
github.com/imroc/req/v3 v3.57.0
github.com/lib/pq v1.10.9
github.com/pquerna/otp v1.5.0
github.com/redis/go-redis/v9 v9.17.2
github.com/refraction-networking/utls v1.8.1
github.com/robfig/cron/v3 v3.0.1
github.com/shirou/gopsutil/v4 v4.25.6
github.com/spf13/viper v1.18.2
github.com/stretchr/testify v1.11.1
@@ -25,13 +30,13 @@ require (
golang.org/x/sync v0.19.0
golang.org/x/term v0.38.0
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
)
require (
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect
dario.cat/mergo v1.0.2 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
@@ -48,7 +53,6 @@ require (
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgraph-io/ristretto v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
@@ -107,13 +111,10 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/pquerna/otp v1.5.0 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.57.1 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
@@ -149,12 +150,10 @@ require (
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect
golang.org/x/tools v0.39.0 // indirect
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
modernc.org/sqlite v1.44.1 // indirect
)

View File

@@ -55,6 +55,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE=
github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU=
github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y=
github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
@@ -113,6 +115,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -123,6 +127,9 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/hashicorp/hcl/v2 v2.18.1 h1:6nxnOJFku1EuSawSD81fuviYUV8DxFr3fp2dUi3ZYSo=
@@ -345,8 +352,6 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
@@ -374,9 +379,8 @@ golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY=
golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM=
golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -399,12 +403,32 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas=
modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.44.3 h1:+39JvV/HWMcYslAwRxHb8067w+2zowvFOUrOWIy9PjY=
modernc.org/sqlite v1.44.3/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -415,6 +415,8 @@ type RedisConfig struct {
PoolSize int `mapstructure:"pool_size"`
// MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟
MinIdleConns int `mapstructure:"min_idle_conns"`
// EnableTLS: 是否启用 TLS/SSL 连接
EnableTLS bool `mapstructure:"enable_tls"`
}
func (r *RedisConfig) Address() string {
@@ -762,6 +764,7 @@ func setDefaults() {
viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
viper.SetDefault("redis.enable_tls", false)
// Ops (vNext)
viper.SetDefault("ops.enabled", true)

View File

@@ -0,0 +1,226 @@
package domain
import (
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
AnnouncementStatusDraft = "draft"
AnnouncementStatusActive = "active"
AnnouncementStatusArchived = "archived"
)
const (
AnnouncementConditionTypeSubscription = "subscription"
AnnouncementConditionTypeBalance = "balance"
)
const (
AnnouncementOperatorIn = "in"
AnnouncementOperatorGT = "gt"
AnnouncementOperatorGTE = "gte"
AnnouncementOperatorLT = "lt"
AnnouncementOperatorLTE = "lte"
AnnouncementOperatorEQ = "eq"
)
var (
ErrAnnouncementNotFound = infraerrors.NotFound("ANNOUNCEMENT_NOT_FOUND", "announcement not found")
ErrAnnouncementInvalidTarget = infraerrors.BadRequest("ANNOUNCEMENT_INVALID_TARGET", "invalid announcement targeting rules")
)
type AnnouncementTargeting struct {
// AnyOf 表示 OR任意一个条件组满足即可展示。
AnyOf []AnnouncementConditionGroup `json:"any_of,omitempty"`
}
type AnnouncementConditionGroup struct {
// AllOf 表示 AND组内所有条件都满足才算命中该组。
AllOf []AnnouncementCondition `json:"all_of,omitempty"`
}
type AnnouncementCondition struct {
// Type: subscription | balance
Type string `json:"type"`
// Operator:
// - subscription: in
// - balance: gt/gte/lt/lte/eq
Operator string `json:"operator"`
// subscription 条件匹配的订阅套餐group_id
GroupIDs []int64 `json:"group_ids,omitempty"`
// balance 条件:比较阈值
Value float64 `json:"value,omitempty"`
}
func (t AnnouncementTargeting) Matches(balance float64, activeSubscriptionGroupIDs map[int64]struct{}) bool {
// 空规则:展示给所有用户
if len(t.AnyOf) == 0 {
return true
}
for _, group := range t.AnyOf {
if len(group.AllOf) == 0 {
// 空条件组不命中(避免 OR 中出现无条件 “全命中”)
continue
}
allMatched := true
for _, cond := range group.AllOf {
if !cond.Matches(balance, activeSubscriptionGroupIDs) {
allMatched = false
break
}
}
if allMatched {
return true
}
}
return false
}
func (c AnnouncementCondition) Matches(balance float64, activeSubscriptionGroupIDs map[int64]struct{}) bool {
switch c.Type {
case AnnouncementConditionTypeSubscription:
if c.Operator != AnnouncementOperatorIn {
return false
}
if len(c.GroupIDs) == 0 {
return false
}
if len(activeSubscriptionGroupIDs) == 0 {
return false
}
for _, gid := range c.GroupIDs {
if _, ok := activeSubscriptionGroupIDs[gid]; ok {
return true
}
}
return false
case AnnouncementConditionTypeBalance:
switch c.Operator {
case AnnouncementOperatorGT:
return balance > c.Value
case AnnouncementOperatorGTE:
return balance >= c.Value
case AnnouncementOperatorLT:
return balance < c.Value
case AnnouncementOperatorLTE:
return balance <= c.Value
case AnnouncementOperatorEQ:
return balance == c.Value
default:
return false
}
default:
return false
}
}
func (t AnnouncementTargeting) NormalizeAndValidate() (AnnouncementTargeting, error) {
normalized := AnnouncementTargeting{AnyOf: make([]AnnouncementConditionGroup, 0, len(t.AnyOf))}
// 允许空 targeting展示给所有用户
if len(t.AnyOf) == 0 {
return normalized, nil
}
if len(t.AnyOf) > 50 {
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
}
for _, g := range t.AnyOf {
if len(g.AllOf) == 0 {
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
}
if len(g.AllOf) > 50 {
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
}
group := AnnouncementConditionGroup{AllOf: make([]AnnouncementCondition, 0, len(g.AllOf))}
for _, c := range g.AllOf {
cond := AnnouncementCondition{
Type: strings.TrimSpace(c.Type),
Operator: strings.TrimSpace(c.Operator),
Value: c.Value,
}
for _, gid := range c.GroupIDs {
if gid <= 0 {
return AnnouncementTargeting{}, ErrAnnouncementInvalidTarget
}
cond.GroupIDs = append(cond.GroupIDs, gid)
}
if err := cond.validate(); err != nil {
return AnnouncementTargeting{}, err
}
group.AllOf = append(group.AllOf, cond)
}
normalized.AnyOf = append(normalized.AnyOf, group)
}
return normalized, nil
}
func (c AnnouncementCondition) validate() error {
switch c.Type {
case AnnouncementConditionTypeSubscription:
if c.Operator != AnnouncementOperatorIn {
return ErrAnnouncementInvalidTarget
}
if len(c.GroupIDs) == 0 {
return ErrAnnouncementInvalidTarget
}
return nil
case AnnouncementConditionTypeBalance:
switch c.Operator {
case AnnouncementOperatorGT, AnnouncementOperatorGTE, AnnouncementOperatorLT, AnnouncementOperatorLTE, AnnouncementOperatorEQ:
return nil
default:
return ErrAnnouncementInvalidTarget
}
default:
return ErrAnnouncementInvalidTarget
}
}
type Announcement struct {
ID int64
Title string
Content string
Status string
Targeting AnnouncementTargeting
StartsAt *time.Time
EndsAt *time.Time
CreatedBy *int64
UpdatedBy *int64
CreatedAt time.Time
UpdatedAt time.Time
}
func (a *Announcement) IsActiveAt(now time.Time) bool {
if a == nil {
return false
}
if a.Status != AnnouncementStatusActive {
return false
}
if a.StartsAt != nil && now.Before(*a.StartsAt) {
return false
}
if a.EndsAt != nil && !now.Before(*a.EndsAt) {
// ends_at 语义:到点即下线
return false
}
return true
}

View File

@@ -0,0 +1,66 @@
package domain
// Status constants
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// Role constants
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// Platform constants
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
)
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
)
// Redeem type constants
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
RedeemTypeInvitation = "invitation"
)
// PromoCode status constants
const (
PromoCodeStatusActive = "active"
PromoCodeStatusDisabled = "disabled"
)
// Admin adjustment type constants
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)
// Group subscription type constants
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
// Subscription status constants
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)

View File

@@ -84,7 +84,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -102,7 +102,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`

View File

@@ -290,5 +290,9 @@ func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*ser
return &code, nil
}
func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]service.RedeemCode, int64, float64, error) {
return s.redeems, int64(len(s.redeems)), 100.0, nil
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -0,0 +1,246 @@
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AnnouncementHandler handles admin announcement management
type AnnouncementHandler struct {
announcementService *service.AnnouncementService
}
// NewAnnouncementHandler creates a new admin announcement handler
func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
return &AnnouncementHandler{
announcementService: announcementService,
}
}
type CreateAnnouncementRequest struct {
Title string `json:"title" binding:"required"`
Content string `json:"content" binding:"required"`
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
Targeting service.AnnouncementTargeting `json:"targeting"`
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
}
type UpdateAnnouncementRequest struct {
Title *string `json:"title"`
Content *string `json:"content"`
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
Targeting *service.AnnouncementTargeting `json:"targeting"`
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
}
// List handles listing announcements with filters
// GET /api/v1/admin/announcements
func (h *AnnouncementHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := strings.TrimSpace(c.Query("status"))
search := strings.TrimSpace(c.Query("search"))
if len(search) > 200 {
search = search[:200]
}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
items, paginationResult, err := h.announcementService.List(
c.Request.Context(),
params,
service.AnnouncementListFilters{Status: status, Search: search},
)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.Announcement, 0, len(items))
for i := range items {
out = append(out, *dto.AnnouncementFromService(&items[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
// GetByID handles getting an announcement by ID
// GET /api/v1/admin/announcements/:id
func (h *AnnouncementHandler) GetByID(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
item, err := h.announcementService.GetByID(c.Request.Context(), announcementID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AnnouncementFromService(item))
}
// Create handles creating a new announcement
// POST /api/v1/admin/announcements
func (h *AnnouncementHandler) Create(c *gin.Context) {
var req CreateAnnouncementRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
input := &service.CreateAnnouncementInput{
Title: req.Title,
Content: req.Content,
Status: req.Status,
Targeting: req.Targeting,
ActorID: &subject.UserID,
}
if req.StartsAt != nil && *req.StartsAt > 0 {
t := time.Unix(*req.StartsAt, 0)
input.StartsAt = &t
}
if req.EndsAt != nil && *req.EndsAt > 0 {
t := time.Unix(*req.EndsAt, 0)
input.EndsAt = &t
}
created, err := h.announcementService.Create(c.Request.Context(), input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AnnouncementFromService(created))
}
// Update handles updating an announcement
// PUT /api/v1/admin/announcements/:id
func (h *AnnouncementHandler) Update(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
var req UpdateAnnouncementRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
input := &service.UpdateAnnouncementInput{
Title: req.Title,
Content: req.Content,
Status: req.Status,
Targeting: req.Targeting,
ActorID: &subject.UserID,
}
if req.StartsAt != nil {
if *req.StartsAt == 0 {
var cleared *time.Time = nil
input.StartsAt = &cleared
} else {
t := time.Unix(*req.StartsAt, 0)
ptr := &t
input.StartsAt = &ptr
}
}
if req.EndsAt != nil {
if *req.EndsAt == 0 {
var cleared *time.Time = nil
input.EndsAt = &cleared
} else {
t := time.Unix(*req.EndsAt, 0)
ptr := &t
input.EndsAt = &ptr
}
}
updated, err := h.announcementService.Update(c.Request.Context(), announcementID, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AnnouncementFromService(updated))
}
// Delete handles deleting an announcement
// DELETE /api/v1/admin/announcements/:id
func (h *AnnouncementHandler) Delete(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
if err := h.announcementService.Delete(c.Request.Context(), announcementID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Announcement deleted successfully"})
}
// ListReadStatus handles listing users read status for an announcement
// GET /api/v1/admin/announcements/:id/read-status
func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
search := strings.TrimSpace(c.Query("search"))
if len(search) > 200 {
search = search[:200]
}
items, paginationResult, err := h.announcementService.ListUserReadStatus(
c.Request.Context(),
announcementID,
params,
search,
)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, items, paginationResult.Total, page, pageSize)
}

View File

@@ -35,14 +35,20 @@ type CreateGroupRequest struct {
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
// UpdateGroupRequest represents update group request
@@ -58,14 +64,20 @@ type UpdateGroupRequest struct {
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
// List handles listing all groups with pagination
@@ -155,22 +167,26 @@ func (h *GroupHandler) Create(c *gin.Context) {
}
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -196,23 +212,27 @@ func (h *GroupHandler) Update(c *gin.Context) {
}
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: req.Status,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: req.Status,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
response.ErrorFrom(c, err)

View File

@@ -29,7 +29,7 @@ func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
// GenerateRedeemCodesRequest represents generate redeem codes request
type GenerateRedeemCodesRequest struct {
Count int `json:"count" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value" binding:"min=0"`
GroupID *int64 `json:"group_id"` // 订阅类型必填
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用默认30天最大100年

View File

@@ -49,6 +49,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: settings.SMTPHost,
@@ -73,6 +74,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
@@ -92,11 +95,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
@@ -119,14 +123,16 @@ type UpdateSettingsRequest struct {
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
@@ -242,6 +248,34 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// “购买订阅”页面配置验证
purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled
if req.PurchaseSubscriptionEnabled != nil {
purchaseEnabled = *req.PurchaseSubscriptionEnabled
}
purchaseURL := previousSettings.PurchaseSubscriptionURL
if req.PurchaseSubscriptionURL != nil {
purchaseURL = strings.TrimSpace(*req.PurchaseSubscriptionURL)
}
// - 启用时要求 URL 合法且非空
// - 禁用时允许为空;若提供了 URL 也做基本校验,避免误配置
if purchaseEnabled {
if purchaseURL == "" {
response.BadRequest(c, "Purchase Subscription URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
return
}
} else if purchaseURL != "" {
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
return
}
}
// Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds
@@ -255,42 +289,45 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -336,6 +373,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: updatedSettings.SMTPHost,
@@ -360,6 +398,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
DocURL: updatedSettings.DocURL,
HomeContent: updatedSettings.HomeContent,
HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,

View File

@@ -277,3 +277,44 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
response.Success(c, stats)
}
// GetBalanceHistory handles getting user's balance/concurrency change history
// GET /api/v1/admin/users/:id/balance-history
// Query params:
// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription)
func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
page, pageSize := response.ParsePagination(c)
codeType := c.Query("type")
codes, total, totalRecharged, err := h.adminService.GetUserBalanceHistory(c.Request.Context(), userID, page, pageSize, codeType)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Convert to admin DTO (includes notes field for admin visibility)
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
// Custom response with total_recharged alongside pagination
pages := int((total + int64(pageSize) - 1) / int64(pageSize))
if pages < 1 {
pages = 1
}
response.Success(c, gin.H{
"items": out,
"total": total,
"page": page,
"page_size": pageSize,
"pages": pages,
"total_recharged": totalRecharged,
})
}

View File

@@ -0,0 +1,81 @@
package handler
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AnnouncementHandler handles user announcement operations
type AnnouncementHandler struct {
announcementService *service.AnnouncementService
}
// NewAnnouncementHandler creates a new user announcement handler
func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
return &AnnouncementHandler{
announcementService: announcementService,
}
}
// List handles listing announcements visible to current user
// GET /api/v1/announcements
func (h *AnnouncementHandler) List(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
unreadOnly := parseBoolQuery(c.Query("unread_only"))
items, err := h.announcementService.ListForUser(c.Request.Context(), subject.UserID, unreadOnly)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.UserAnnouncement, 0, len(items))
for i := range items {
out = append(out, *dto.UserAnnouncementFromService(&items[i]))
}
response.Success(c, out)
}
// MarkRead marks an announcement as read for current user
// POST /api/v1/announcements/:id/read
func (h *AnnouncementHandler) MarkRead(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
if err := h.announcementService.MarkRead(c.Request.Context(), subject.UserID, announcementID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "ok"})
}
func parseBoolQuery(v string) bool {
switch strings.TrimSpace(strings.ToLower(v)) {
case "1", "true", "yes", "y", "on":
return true
default:
return false
}
}

View File

@@ -3,6 +3,7 @@ package handler
import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -27,11 +28,13 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
// CreateAPIKeyRequest represents the create API key request payload
type CreateAPIKeyRequest struct {
Name string `json:"name" binding:"required"`
GroupID *int64 `json:"group_id"` // nullable
CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
Name string `json:"name" binding:"required"`
GroupID *int64 `json:"group_id"` // nullable
CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
Quota *float64 `json:"quota"` // 配额限制 (USD)
ExpiresInDays *int `json:"expires_in_days"` // 过期天数
}
// UpdateAPIKeyRequest represents the update API key request payload
@@ -41,6 +44,9 @@ type UpdateAPIKeyRequest struct {
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制
ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601)
ResetQuota *bool `json:"reset_quota"` // 重置已用配额
}
// List handles listing user's API keys with pagination
@@ -114,11 +120,15 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
}
svcReq := service.CreateAPIKeyRequest{
Name: req.Name,
GroupID: req.GroupID,
CustomKey: req.CustomKey,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
Name: req.Name,
GroupID: req.GroupID,
CustomKey: req.CustomKey,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
ExpiresInDays: req.ExpiresInDays,
}
if req.Quota != nil {
svcReq.Quota = *req.Quota
}
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
@@ -153,6 +163,8 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
svcReq := service.UpdateAPIKeyRequest{
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
Quota: req.Quota,
ResetQuota: req.ResetQuota,
}
if req.Name != "" {
svcReq.Name = &req.Name
@@ -161,6 +173,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
if req.Status != "" {
svcReq.Status = &req.Status
}
// Parse expires_at if provided
if req.ExpiresAt != nil {
if *req.ExpiresAt == "" {
// Empty string means clear expiration
svcReq.ExpiresAt = nil
svcReq.ClearExpiration = true
} else {
t, err := time.Parse(time.RFC3339, *req.ExpiresAt)
if err != nil {
response.BadRequest(c, "Invalid expires_at format: "+err.Error())
return
}
svcReq.ExpiresAt = &t
}
}
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq)
if err != nil {

View File

@@ -15,23 +15,25 @@ import (
// AuthHandler handles authentication-related requests
type AuthHandler struct {
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
totpService *service.TotpService
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
redeemService *service.RedeemService
totpService *service.TotpService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler {
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
return &AuthHandler{
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
totpService: totpService,
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
redeemService: redeemService,
totpService: totpService,
}
}
@@ -41,7 +43,8 @@ type RegisterRequest struct {
Password string `json:"password" binding:"required,min=6"`
VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码
PromoCode string `json:"promo_code"` // 注册优惠码
InvitationCode string `json:"invitation_code"` // 邀请码
}
// SendVerifyCodeRequest 发送验证码请求
@@ -87,7 +90,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
}
}
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -346,6 +349,67 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
})
}
// ValidateInvitationCodeRequest 验证邀请码请求
type ValidateInvitationCodeRequest struct {
Code string `json:"code" binding:"required"`
}
// ValidateInvitationCodeResponse 验证邀请码响应
type ValidateInvitationCodeResponse struct {
Valid bool `json:"valid"`
ErrorCode string `json:"error_code,omitempty"`
}
// ValidateInvitationCode 验证邀请码(公开接口,注册前调用)
// POST /api/v1/auth/validate-invitation-code
func (h *AuthHandler) ValidateInvitationCode(c *gin.Context) {
// 检查邀请码功能是否启用
if h.settingSvc == nil || !h.settingSvc.IsInvitationCodeEnabled(c.Request.Context()) {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_DISABLED",
})
return
}
var req ValidateInvitationCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 验证邀请码
redeemCode, err := h.redeemService.GetByCode(c.Request.Context(), req.Code)
if err != nil {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_NOT_FOUND",
})
return
}
// 检查类型和状态
if redeemCode.Type != service.RedeemTypeInvitation {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_INVALID",
})
return
}
if redeemCode.Status != service.StatusUnused {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_USED",
})
return
}
response.Success(c, ValidateInvitationCodeResponse{
Valid: true,
})
}
// ForgotPasswordRequest 忘记密码请求
type ForgotPasswordRequest struct {
Email string `json:"email" binding:"required,email"`

View File

@@ -0,0 +1,74 @@
package dto
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type Announcement struct {
ID int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
Status string `json:"status"`
Targeting service.AnnouncementTargeting `json:"targeting"`
StartsAt *time.Time `json:"starts_at,omitempty"`
EndsAt *time.Time `json:"ends_at,omitempty"`
CreatedBy *int64 `json:"created_by,omitempty"`
UpdatedBy *int64 `json:"updated_by,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type UserAnnouncement struct {
ID int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
StartsAt *time.Time `json:"starts_at,omitempty"`
EndsAt *time.Time `json:"ends_at,omitempty"`
ReadAt *time.Time `json:"read_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func AnnouncementFromService(a *service.Announcement) *Announcement {
if a == nil {
return nil
}
return &Announcement{
ID: a.ID,
Title: a.Title,
Content: a.Content,
Status: a.Status,
Targeting: a.Targeting,
StartsAt: a.StartsAt,
EndsAt: a.EndsAt,
CreatedBy: a.CreatedBy,
UpdatedBy: a.UpdatedBy,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
}
}
func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement {
if a == nil {
return nil
}
return &UserAnnouncement{
ID: a.Announcement.ID,
Title: a.Announcement.Title,
Content: a.Announcement.Content,
StartsAt: a.Announcement.StartsAt,
EndsAt: a.Announcement.EndsAt,
ReadAt: a.ReadAt,
CreatedAt: a.Announcement.CreatedAt,
UpdatedAt: a.Announcement.UpdatedAt,
}
}

View File

@@ -76,6 +76,9 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
Status: k.Status,
IPWhitelist: k.IPWhitelist,
IPBlacklist: k.IPBlacklist,
Quota: k.Quota,
QuotaUsed: k.QuotaUsed,
ExpiresAt: k.ExpiresAt,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User),
@@ -105,10 +108,12 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil
}
out := &AdminGroup{
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
AccountCount: g.AccountCount,
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
@@ -138,8 +143,10 @@ func groupFromServiceBase(g *service.Group) Group {
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
}
@@ -204,6 +211,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
}
}
if scopeLimits := a.GetAntigravityScopeRateLimits(); len(scopeLimits) > 0 {
out.ScopeRateLimits = make(map[string]ScopeRateLimitInfo, len(scopeLimits))
now := time.Now()
for scope, remainingSec := range scopeLimits {
out.ScopeRateLimits[scope] = ScopeRateLimitInfo{
ResetAt: now.Add(time.Duration(remainingSec) * time.Second),
RemainingSec: remainingSec,
}
}
}
return out
}
@@ -321,7 +339,7 @@ func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
}
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
return RedeemCode{
out := RedeemCode{
ID: rc.ID,
Code: rc.Code,
Type: rc.Type,
@@ -335,6 +353,14 @@ func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group),
}
// For admin_balance/admin_concurrency types, include notes so users can see
// why they were charged or credited by admin
if (rc.Type == "admin_balance" || rc.Type == "admin_concurrency") && rc.Notes != "" {
out.Notes = &rc.Notes
}
return out
}
// AccountSummaryFromService returns a minimal AccountSummary for usage log display.
@@ -358,6 +384,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
ReasoningEffort: l.ReasoningEffort,
GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens,

View File

@@ -6,6 +6,7 @@ type SystemSettings struct {
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
@@ -26,14 +27,16 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
@@ -57,23 +60,26 @@ type SystemSettings struct {
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
}
// StreamTimeoutSettings 流超时处理配置 DTO

View File

@@ -2,6 +2,11 @@ package dto
import "time"
type ScopeRateLimitInfo struct {
ResetAt time.Time `json:"reset_at"`
RemainingSec int64 `json:"remaining_sec"`
}
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
@@ -27,16 +32,19 @@ type AdminUser struct {
}
type APIKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Key string `json:"key"`
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"`
IPBlacklist []string `json:"ip_blacklist"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Key string `json:"key"`
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"`
IPBlacklist []string `json:"ip_blacklist"`
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires)
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
@@ -64,6 +72,8 @@ type Group struct {
// Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
@@ -78,8 +88,13 @@ type AdminGroup struct {
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
// MCP XML 协议注入(仅 antigravity 平台使用)
MCPXMLInject bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
}
type Account struct {
@@ -108,6 +123,9 @@ type Account struct {
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
// Antigravity scope 级限流状态(从 extra 提取)
ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
@@ -198,6 +216,10 @@ type RedeemCode struct {
GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
// Notes is only populated for admin_balance/admin_concurrency types
// so users can see why they were charged or credited
Notes *string `json:"notes,omitempty"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
@@ -218,6 +240,9 @@ type UsageLog struct {
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
// nil means not provided / not applicable.
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"`

View File

@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -30,6 +31,8 @@ type GatewayHandler struct {
antigravityGatewayService *service.AntigravityGatewayService
userService *service.UserService
billingCacheService *service.BillingCacheService
usageService *service.UsageService
apiKeyService *service.APIKeyService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
@@ -43,6 +46,8 @@ func NewGatewayHandler(
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
usageService *service.UsageService,
apiKeyService *service.APIKeyService,
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
@@ -63,6 +68,8 @@ func NewGatewayHandler(
antigravityGatewayService: antigravityGatewayService,
userService: userService,
billingCacheService: billingCacheService,
usageService: usageService,
apiKeyService: apiKeyService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
@@ -278,10 +285,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body)
} else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
result, err = h.geminiCompatService.Forward(requestCtx, c, account, body)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
@@ -313,13 +324,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: clientIP,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
@@ -328,139 +340,193 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
currentAPIKey := apiKey
currentSubscription := subscription
var fallbackGroupID *int64
if apiKey.Group != nil {
fallbackGroupID = apiKey.Group.FallbackGroupIDOnInvalidRequest
}
fallbackUsed := false
for {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
retryWithFallback := false
// 检查请求拦截预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
}
}
// 3. 获取账号并发槽位
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
for {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检查请求拦截预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
}
}
// 3. 获取账号并发槽位
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body)
} else {
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var promptTooLongErr *service.PromptTooLongError
if errors.As(err, &promptTooLongErr) {
log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed)
if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 {
fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID)
if err != nil {
log.Printf("Resolve fallback group failed: %v", err)
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
if fallbackGroup.Platform != service.PlatformAnthropic ||
fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription ||
fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType)
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
// 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "")
c.Request = c.Request.WithContext(ctx)
currentAPIKey = fallbackAPIKey
currentSubscription = nil
fallbackUsed = true
retryWithFallback = true
break
}
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
return
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: usedAccount,
Subscription: currentSubscription,
UserAgent: ua,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent, clientIP)
return
}
if !retryWithFallback {
return
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: clientIP,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent, clientIP)
return
}
}
@@ -524,7 +590,18 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
})
}
// Usage handles getting account balance for CC Switch integration
func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service.APIKey {
if apiKey == nil || group == nil {
return apiKey
}
cloned := *apiKey
groupID := group.ID
cloned.GroupID = &groupID
cloned.Group = group
return &cloned
}
// Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
@@ -539,7 +616,40 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return
}
// 订阅模式:返回订阅限额信息
// Best-effort: 获取用量统计,失败不影响基础响应
var usageData gin.H
if h.usageService != nil {
dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
if err == nil && dashStats != nil {
usageData = gin.H{
"today": gin.H{
"requests": dashStats.TodayRequests,
"input_tokens": dashStats.TodayInputTokens,
"output_tokens": dashStats.TodayOutputTokens,
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
"cache_read_tokens": dashStats.TodayCacheReadTokens,
"total_tokens": dashStats.TodayTokens,
"cost": dashStats.TodayCost,
"actual_cost": dashStats.TodayActualCost,
},
"total": gin.H{
"requests": dashStats.TotalRequests,
"input_tokens": dashStats.TotalInputTokens,
"output_tokens": dashStats.TotalOutputTokens,
"cache_creation_tokens": dashStats.TotalCacheCreationTokens,
"cache_read_tokens": dashStats.TotalCacheReadTokens,
"total_tokens": dashStats.TotalTokens,
"cost": dashStats.TotalCost,
"actual_cost": dashStats.TotalActualCost,
},
"average_duration_ms": dashStats.AverageDurationMs,
"rpm": dashStats.Rpm,
"tpm": dashStats.Tpm,
}
}
}
// 订阅模式:返回订阅限额信息 + 用量统计
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware2.GetSubscriptionFromContext(c)
if !ok {
@@ -548,28 +658,46 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
}
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
c.JSON(http.StatusOK, gin.H{
resp := gin.H{
"isValid": true,
"planName": apiKey.Group.Name,
"remaining": remaining,
"unit": "USD",
})
"subscription": gin.H{
"daily_usage_usd": subscription.DailyUsageUSD,
"weekly_usage_usd": subscription.WeeklyUsageUSD,
"monthly_usage_usd": subscription.MonthlyUsageUSD,
"daily_limit_usd": apiKey.Group.DailyLimitUSD,
"weekly_limit_usd": apiKey.Group.WeeklyLimitUSD,
"monthly_limit_usd": apiKey.Group.MonthlyLimitUSD,
"expires_at": subscription.ExpiresAt,
},
}
if usageData != nil {
resp["usage"] = usageData
}
c.JSON(http.StatusOK, resp)
return
}
// 余额模式:返回钱包余额
// 余额模式:返回钱包余额 + 用量统计
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return
}
c.JSON(http.StatusOK, gin.H{
resp := gin.H{
"isValid": true,
"planName": "钱包余额",
"remaining": latestUser.Balance,
"unit": "USD",
})
"balance": latestUser.Balance,
}
if usageData != nil {
resp["usage"] = usageData
}
c.JSON(http.StatusOK, resp)
}
// calculateSubscriptionRemaining 计算订阅剩余可用额度
@@ -725,6 +853,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)

View File

@@ -14,6 +14,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
@@ -335,10 +336,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 5) forward (根据平台分流)
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body)
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body)
} else {
result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
@@ -366,18 +371,22 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 6) record usage async
// 6) record usage async (Gemini 使用长上下文双倍计费)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}

View File

@@ -10,6 +10,7 @@ type AdminHandlers struct {
User *admin.UserHandler
Group *admin.GroupHandler
Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
@@ -33,6 +34,7 @@ type Handlers struct {
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
Announcement *AnnouncementHandler
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler

View File

@@ -24,6 +24,7 @@ import (
type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
}
@@ -33,6 +34,7 @@ func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
apiKeyService *service.APIKeyService,
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
@@ -46,6 +48,7 @@ func NewOpenAIGatewayHandler(
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
}
@@ -299,13 +302,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}

View File

@@ -905,7 +905,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
return true
}
if phase == "billing" || phase == "concurrency" {
@@ -1011,5 +1011,12 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
}
}
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
if settings.IgnoreInvalidApiKeyErrors {
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
return true
}
}
return false
}

View File

@@ -32,21 +32,25 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
}
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
})
}

View File

@@ -13,6 +13,7 @@ func ProvideAdminHandlers(
userHandler *admin.UserHandler,
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
announcementHandler *admin.AnnouncementHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
@@ -32,6 +33,7 @@ func ProvideAdminHandlers(
User: userHandler,
Group: groupHandler,
Account: accountHandler,
Announcement: announcementHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
@@ -66,6 +68,7 @@ func ProvideHandlers(
usageHandler *UsageHandler,
redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler,
announcementHandler *AnnouncementHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
@@ -79,6 +82,7 @@ func ProvideHandlers(
Usage: usageHandler,
Redeem: redeemHandler,
Subscription: subscriptionHandler,
Announcement: announcementHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
@@ -96,6 +100,7 @@ var ProviderSet = wire.NewSet(
NewUsageHandler,
NewRedeemHandler,
NewSubscriptionHandler,
NewAnnouncementHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
NewTotpHandler,
@@ -106,6 +111,7 @@ var ProviderSet = wire.NewSet(
admin.NewUserHandler,
admin.NewGroupHandler,
admin.NewAccountHandler,
admin.NewAnnouncementHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,

View File

@@ -33,24 +33,55 @@ const (
"https://www.googleapis.com/auth/experimentsandconfigs"
// User-Agent与 Antigravity-Manager 保持一致)
UserAgent = "antigravity/1.11.9 windows/amd64"
UserAgent = "antigravity/1.15.8 windows/amd64"
// Session 过期时间
SessionTTL = 30 * time.Minute
// URL 可用性 TTL不可用 URL 的恢复时间)
URLAvailabilityTTL = 5 * time.Minute
// Antigravity API 端点
antigravityProdBaseURL = "https://cloudcode-pa.googleapis.com"
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
)
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
var BaseURLs = []string{
"https://cloudcode-pa.googleapis.com", // prod (优先)
"https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用)
antigravityProdBaseURL, // prod (优先)
antigravityDailyBaseURL, // daily sandbox (备用)
}
// BaseURL 默认 URL保持向后兼容
var BaseURL = BaseURLs[0]
// ForwardBaseURLs 返回 API 转发用的 URL 顺序daily 优先)
func ForwardBaseURLs() []string {
if len(BaseURLs) == 0 {
return nil
}
urls := append([]string(nil), BaseURLs...)
dailyIndex := -1
for i, url := range urls {
if url == antigravityDailyBaseURL {
dailyIndex = i
break
}
}
if dailyIndex <= 0 {
return urls
}
reordered := make([]string, 0, len(urls))
reordered = append(reordered, urls[dailyIndex])
for i, url := range urls {
if i == dailyIndex {
continue
}
reordered = append(reordered, url)
}
return reordered
}
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级)
type URLAvailability struct {
mu sync.RWMutex
@@ -100,22 +131,37 @@ func (u *URLAvailability) IsAvailable(url string) bool {
// GetAvailableURLs 返回可用的 URL 列表
// 最近成功的 URL 优先,其他按默认顺序
func (u *URLAvailability) GetAvailableURLs() []string {
return u.GetAvailableURLsWithBase(BaseURLs)
}
// GetAvailableURLsWithBase 返回可用的 URL 列表(使用自定义顺序)
// 最近成功的 URL 优先,其他按传入顺序
func (u *URLAvailability) GetAvailableURLsWithBase(baseURLs []string) []string {
u.mu.RLock()
defer u.mu.RUnlock()
now := time.Now()
result := make([]string, 0, len(BaseURLs))
result := make([]string, 0, len(baseURLs))
// 如果有最近成功的 URL 且可用,放在最前面
if u.lastSuccess != "" {
expiry, exists := u.unavailable[u.lastSuccess]
if !exists || now.After(expiry) {
result = append(result, u.lastSuccess)
found := false
for _, url := range baseURLs {
if url == u.lastSuccess {
found = true
break
}
}
if found {
expiry, exists := u.unavailable[u.lastSuccess]
if !exists || now.After(expiry) {
result = append(result, u.lastSuccess)
}
}
}
// 添加其他可用的 URL默认顺序)
for _, url := range BaseURLs {
// 添加其他可用的 URL传入顺序)
for _, url := range baseURLs {
// 跳过已添加的 lastSuccess
if url == u.lastSuccess {
continue

View File

@@ -44,11 +44,13 @@ type TransformOptions struct {
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
// 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。
IdentityPatch string
EnableMCPXML bool
}
func DefaultTransformOptions() TransformOptions {
return TransformOptions{
EnableIdentityPatch: true,
EnableMCPXML: true,
}
}
@@ -257,8 +259,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
// 添加用户的 system prompt
parts = append(parts, userSystemParts...)
// 检测是否有 MCP 工具,如有则注入 XML 调用协议
if hasMCPTools(tools) {
// 检测是否有 MCP 工具,如有且启用了 MCP XML 注入则注入 XML 调用协议
if opts.EnableMCPXML && hasMCPTools(tools) {
parts = append(parts, GeminiPart{Text: mcpXMLProtocol})
}
@@ -312,7 +314,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
parts = append([]GeminiPart{{
Text: "Thinking...",
Thought: true,
ThoughtSignature: dummyThoughtSignature,
ThoughtSignature: DummyThoughtSignature,
}}, parts...)
}
}
@@ -330,9 +332,10 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
return contents, strippedThinking, nil
}
// dummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
// DummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator"
// 导出供跨包使用(如 gemini_native_signature_cleaner 跨账号修复)
const DummyThoughtSignature = "skip_thought_signature_validator"
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
@@ -370,7 +373,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
// signature 处理:
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失)
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
if block.Signature != "" && (allowDummyThought || block.Signature != DummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature在缺失时降级为普通文本并在上层禁用 thinking mode。
@@ -381,7 +384,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
continue
} else {
// Gemini 模型使用 dummy signature
part.ThoughtSignature = dummyThoughtSignature
part.ThoughtSignature = DummyThoughtSignature
}
parts = append(parts, part)
@@ -411,10 +414,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
// tool_use 的 signature 处理:
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失)
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
if block.Signature != "" && (allowDummyThought || block.Signature != DummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
part.ThoughtSignature = DummyThoughtSignature
}
parts = append(parts, part)
@@ -492,9 +495,23 @@ func parseToolResultContent(content json.RawMessage, isError bool) string {
}
// buildGenerationConfig 构建 generationConfig
const (
defaultMaxOutputTokens = 64000
maxOutputTokensUpperBound = 65000
maxOutputTokensClaude = 64000
)
func maxOutputTokensLimit(model string) int {
if strings.HasPrefix(model, "claude-") {
return maxOutputTokensClaude
}
return maxOutputTokensUpperBound
}
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
maxLimit := maxOutputTokensLimit(req.Model)
config := &GeminiGenerationConfig{
MaxOutputTokens: 64000, // 默认最大输出
MaxOutputTokens: defaultMaxOutputTokens, // 默认最大输出
StopSequences: DefaultStopSequences,
}
@@ -518,6 +535,10 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
}
}
if config.MaxOutputTokens > maxLimit {
config.MaxOutputTokens = maxLimit
}
// 其他参数
if req.Temperature != nil {
config.Temperature = req.Temperature

View File

@@ -86,7 +86,7 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
if len(parts) != 3 {
t.Fatalf("expected 3 parts, got %d", len(parts))
}
if !parts[1].Thought || parts[1].ThoughtSignature != dummyThoughtSignature {
if !parts[1].Thought || parts[1].ThoughtSignature != DummyThoughtSignature {
t.Fatalf("expected dummy thought signature, got thought=%v signature=%q",
parts[1].Thought, parts[1].ThoughtSignature)
}
@@ -126,8 +126,8 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
if parts[0].ThoughtSignature != DummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", DummyThoughtSignature, parts[0].ThoughtSignature)
}
})

View File

@@ -9,11 +9,26 @@ const (
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting = "token-counting-2024-11-01"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
//
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
// even if the request doesn't use tools, otherwise upstream may reject the
// request as a non-Claude-Code API request.
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
// Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
"User-Agent": "claude-cli/2.1.22 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0",
"X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64",
"X-Stainless-Arch": "arm64",
"X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0",
"X-Stainless-Runtime-Version": "v24.13.0",
"X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60",
"X-Stainless-Timeout": "600",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
}
@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929"
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
var ModelIDOverrides = map[string]string{
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
"claude-opus-4-5": "claude-opus-4-5-20251101",
"claude-haiku-4-5": "claude-haiku-4-5-20251001",
}
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
var ModelIDReverseOverrides = map[string]string{
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-opus-4-5-20251101": "claude-opus-4-5",
"claude-haiku-4-5-20251001": "claude-haiku-4-5",
}
// NormalizeModelID 根据 Claude OAuth 规则映射模型
func NormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDOverrides[id]; ok {
return mapped
}
return id
}
// DenormalizeModelID 将上游模型 ID 转换为短名
func DenormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDReverseOverrides[id]; ok {
return mapped
}
return id
}

View File

@@ -14,6 +14,9 @@ const (
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
RetryCount Key = "ctx_retry_count"
// AccountSwitchCount 表示请求过程中发生的账号切换次数
AccountSwitchCount Key = "ctx_account_switch_count"
// IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置

View File

@@ -2,6 +2,7 @@
package response
import (
"log"
"math"
"net/http"
@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool {
}
statusCode, status := infraerrors.ToHTTP(err)
// Log internal errors with full details for debugging
if statusCode >= 500 && c.Request != nil {
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
}
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}

View File

@@ -809,12 +809,21 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
return err
}
path := "{antigravity_quota_scopes," + string(scope) + "}"
scopeKey := string(scope)
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
path,
`UPDATE accounts SET
extra = jsonb_set(
jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true),
ARRAY['antigravity_quota_scopes', $1]::text[],
$2::jsonb,
true
),
updated_at = NOW(),
last_used_at = NOW()
WHERE id = $3 AND deleted_at IS NULL`,
scopeKey,
raw,
id,
)
@@ -829,6 +838,7 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
}
@@ -849,12 +859,19 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
return err
}
path := "{model_rate_limits," + scope + "}"
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
path,
`UPDATE accounts SET
extra = jsonb_set(
jsonb_set(COALESCE(extra, '{}'::jsonb), '{model_rate_limits}'::text[], COALESCE(extra->'model_rate_limits', '{}'::jsonb), true),
ARRAY['model_rate_limits', $1]::text[],
$2::jsonb,
true
),
updated_at = NOW()
WHERE id = $3 AND deleted_at IS NULL`,
scope,
raw,
id,
)

View File

@@ -0,0 +1,83 @@
package repository
import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type announcementReadRepository struct {
client *dbent.Client
}
func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementReadRepository {
return &announcementReadRepository{client: client}
}
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
client := clientFromContext(ctx, r.client)
return client.AnnouncementRead.Create().
SetAnnouncementID(announcementID).
SetUserID(userID).
SetReadAt(readAt).
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
DoNothing().
Exec(ctx)
}
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
if len(announcementIDs) == 0 {
return map[int64]time.Time{}, nil
}
rows, err := r.client.AnnouncementRead.Query().
Where(
announcementread.UserIDEQ(userID),
announcementread.AnnouncementIDIn(announcementIDs...),
).
All(ctx)
if err != nil {
return nil, err
}
out := make(map[int64]time.Time, len(rows))
for i := range rows {
out[rows[i].AnnouncementID] = rows[i].ReadAt
}
return out, nil
}
func (r *announcementReadRepository) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
if len(userIDs) == 0 {
return map[int64]time.Time{}, nil
}
rows, err := r.client.AnnouncementRead.Query().
Where(
announcementread.AnnouncementIDEQ(announcementID),
announcementread.UserIDIn(userIDs...),
).
All(ctx)
if err != nil {
return nil, err
}
out := make(map[int64]time.Time, len(rows))
for i := range rows {
out[rows[i].UserID] = rows[i].ReadAt
}
return out, nil
}
func (r *announcementReadRepository) CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error) {
count, err := r.client.AnnouncementRead.Query().
Where(announcementread.AnnouncementIDEQ(announcementID)).
Count(ctx)
if err != nil {
return 0, err
}
return int64(count), nil
}

View File

@@ -0,0 +1,194 @@
package repository
import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type announcementRepository struct {
client *dbent.Client
}
func NewAnnouncementRepository(client *dbent.Client) service.AnnouncementRepository {
return &announcementRepository{client: client}
}
func (r *announcementRepository) Create(ctx context.Context, a *service.Announcement) error {
client := clientFromContext(ctx, r.client)
builder := client.Announcement.Create().
SetTitle(a.Title).
SetContent(a.Content).
SetStatus(a.Status).
SetTargeting(a.Targeting)
if a.StartsAt != nil {
builder.SetStartsAt(*a.StartsAt)
}
if a.EndsAt != nil {
builder.SetEndsAt(*a.EndsAt)
}
if a.CreatedBy != nil {
builder.SetCreatedBy(*a.CreatedBy)
}
if a.UpdatedBy != nil {
builder.SetUpdatedBy(*a.UpdatedBy)
}
created, err := builder.Save(ctx)
if err != nil {
return err
}
applyAnnouncementEntityToService(a, created)
return nil
}
func (r *announcementRepository) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
m, err := r.client.Announcement.Query().
Where(announcement.IDEQ(id)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
}
return announcementEntityToService(m), nil
}
func (r *announcementRepository) Update(ctx context.Context, a *service.Announcement) error {
client := clientFromContext(ctx, r.client)
builder := client.Announcement.UpdateOneID(a.ID).
SetTitle(a.Title).
SetContent(a.Content).
SetStatus(a.Status).
SetTargeting(a.Targeting)
if a.StartsAt != nil {
builder.SetStartsAt(*a.StartsAt)
} else {
builder.ClearStartsAt()
}
if a.EndsAt != nil {
builder.SetEndsAt(*a.EndsAt)
} else {
builder.ClearEndsAt()
}
if a.CreatedBy != nil {
builder.SetCreatedBy(*a.CreatedBy)
} else {
builder.ClearCreatedBy()
}
if a.UpdatedBy != nil {
builder.SetUpdatedBy(*a.UpdatedBy)
} else {
builder.ClearUpdatedBy()
}
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
}
a.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *announcementRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.Announcement.Delete().Where(announcement.IDEQ(id)).Exec(ctx)
return err
}
func (r *announcementRepository) List(
ctx context.Context,
params pagination.PaginationParams,
filters service.AnnouncementListFilters,
) ([]service.Announcement, *pagination.PaginationResult, error) {
q := r.client.Announcement.Query()
if filters.Status != "" {
q = q.Where(announcement.StatusEQ(filters.Status))
}
if filters.Search != "" {
q = q.Where(
announcement.Or(
announcement.TitleContainsFold(filters.Search),
announcement.ContentContainsFold(filters.Search),
),
)
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
items, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(announcement.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
out := announcementEntitiesToService(items)
return out, paginationResultFromTotal(int64(total), params), nil
}
func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
q := r.client.Announcement.Query().
Where(
announcement.StatusEQ(service.AnnouncementStatusActive),
announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)),
announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)),
).
Order(dbent.Desc(announcement.FieldID))
items, err := q.All(ctx)
if err != nil {
return nil, err
}
return announcementEntitiesToService(items), nil
}
func applyAnnouncementEntityToService(dst *service.Announcement, src *dbent.Announcement) {
if dst == nil || src == nil {
return
}
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
if m == nil {
return nil
}
return &service.Announcement{
ID: m.ID,
Title: m.Title,
Content: m.Content,
Status: m.Status,
Targeting: m.Targeting,
StartsAt: m.StartsAt,
EndsAt: m.EndsAt,
CreatedBy: m.CreatedBy,
UpdatedBy: m.UpdatedBy,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func announcementEntitiesToService(models []*dbent.Announcement) []service.Announcement {
out := make([]service.Announcement, 0, len(models))
for i := range models {
if s := announcementEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}

View File

@@ -33,7 +33,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
SetKey(key.Key).
SetName(key.Name).
SetStatus(key.Status).
SetNillableGroupID(key.GroupID)
SetNillableGroupID(key.GroupID).
SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed).
SetNillableExpiresAt(key.ExpiresAt)
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
@@ -110,6 +113,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey.FieldStatus,
apikey.FieldIPWhitelist,
apikey.FieldIPBlacklist,
apikey.FieldQuota,
apikey.FieldQuotaUsed,
apikey.FieldExpiresAt,
).
WithUser(func(q *dbent.UserQuery) {
q.Select(
@@ -136,8 +142,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice4k,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
group.FieldFallbackGroupIDOnInvalidRequest,
group.FieldModelRoutingEnabled,
group.FieldModelRouting,
group.FieldMcpXMLInject,
group.FieldSupportedModelScopes,
)
}).
Only(ctx)
@@ -161,6 +170,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed).
SetUpdatedAt(now)
if key.GroupID != nil {
builder.SetGroupID(*key.GroupID)
@@ -168,6 +179,13 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearGroupID()
}
// Expiration time
if key.ExpiresAt != nil {
builder.SetExpiresAt(*key.ExpiresAt)
} else {
builder.ClearExpiresAt()
}
// IP 限制字段
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
@@ -357,6 +375,38 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
return keys, nil
}
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
// Use raw SQL for atomic increment to avoid race conditions
// First get current value
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldQuotaUsed).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrAPIKeyNotFound
}
return 0, err
}
newValue := m.QuotaUsed + amount
// Update with new value
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetQuotaUsed(newValue).
Save(ctx)
if err != nil {
return 0, err
}
if affected == 0 {
return 0, service.ErrAPIKeyNotFound
}
return newValue, nil
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil {
return nil
@@ -372,6 +422,9 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
Quota: m.Quota,
QuotaUsed: m.QuotaUsed,
ExpiresAt: m.ExpiresAt,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
@@ -409,28 +462,31 @@ func groupEntityToService(g *dbent.Group) *service.Group {
return nil
}
return &service.Group{
ID: g.ID,
Name: g.Name,
Description: derefString(g.Description),
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
Hydrated: true,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
ID: g.ID,
Name: g.Name,
Description: derefString(g.Description),
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
Hydrated: true,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.McpXMLInject,
SupportedModelScopes: g.SupportedModelScopes,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
}

View File

@@ -50,13 +50,18 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
builder = builder.SetModelRouting(groupIn.ModelRouting)
}
// 设置支持的模型系列(始终设置,空数组表示不限制)
builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes)
created, err := builder.Save(ctx)
if err == nil {
groupIn.ID = created.ID
@@ -87,7 +92,6 @@ func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.G
if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
return groupEntityToService(m), nil
}
@@ -108,7 +112,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject)
// 处理 FallbackGroupIDnil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
@@ -116,6 +121,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
} else {
builder = builder.ClearFallbackGroupID()
}
// 处理 FallbackGroupIDOnInvalidRequestnil 时清除,否则设置
if groupIn.FallbackGroupIDOnInvalidRequest != nil {
builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest)
} else {
builder = builder.ClearFallbackGroupIDOnInvalidRequest()
}
// 处理 ModelRoutingnil 时清除,否则设置
if groupIn.ModelRouting != nil {
@@ -124,6 +135,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
builder = builder.ClearModelRouting()
}
// 处理 SupportedModelScopes始终设置空数组表示不限制
builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes)
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
@@ -425,3 +439,61 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
return counts, nil
}
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID去重
func (r *groupRepository) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
if len(groupIDs) == 0 {
return nil, nil
}
rows, err := r.sql.QueryContext(
ctx,
"SELECT DISTINCT account_id FROM account_groups WHERE group_id = ANY($1) ORDER BY account_id",
pq.Array(groupIDs),
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var accountIDs []int64
for rows.Next() {
var accountID int64
if err := rows.Scan(&accountID); err != nil {
return nil, err
}
accountIDs = append(accountIDs, accountID)
}
if err := rows.Err(); err != nil {
return nil, err
}
return accountIDs, nil
}
// BindAccountsToGroup 将多个账号绑定到指定分组(批量插入,忽略已存在的绑定)
func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
if len(accountIDs) == 0 {
return nil
}
// 使用 INSERT ... ON CONFLICT DO NOTHING 忽略已存在的绑定
_, err := r.sql.ExecContext(
ctx,
`INSERT INTO account_groups (account_id, group_id, priority, created_at)
SELECT unnest($1::bigint[]), $2, 50, NOW()
ON CONFLICT (account_id, group_id) DO NOTHING`,
pq.Array(accountIDs),
groupID,
)
if err != nil {
return err
}
// 发送调度器事件
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
}
return nil
}

View File

@@ -2,11 +2,11 @@ package repository
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
@@ -22,7 +22,7 @@ type openaiOAuthService struct {
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(s.tokenURL, proxyURL)
client := createOpenAIReqClient(proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
@@ -39,23 +39,24 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
resp, err := client.R().
SetContext(ctx).
SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(s.tokenURL, proxyURL)
client := createOpenAIReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -67,29 +68,25 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
resp, err := client.R().
SetContext(ctx).
SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client {
forceHTTP2 := false
if parsedURL, err := url.Parse(tokenURL); err == nil {
forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https")
}
func createOpenAIReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
ForceHTTP2: forceHTTP2,
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
})
}

View File

@@ -43,6 +43,7 @@ INSERT INTO ops_system_metrics (
upstream_529_count,
token_consumed,
account_switch_count,
qps,
tps,
@@ -81,14 +82,14 @@ INSERT INTO ops_system_metrics (
$1,$2,$3,$4,
$5,$6,$7,$8,
$9,$10,$11,
$12,$13,$14,
$15,$16,$17,$18,$19,$20,
$21,$22,$23,$24,$25,$26,
$27,$28,$29,$30,
$31,$32,
$33,$34,
$35,$36,$37,
$38,$39
$12,$13,$14,$15,
$16,$17,$18,$19,$20,$21,
$22,$23,$24,$25,$26,$27,
$28,$29,$30,$31,
$32,$33,
$34,$35,
$36,$37,$38,
$39,$40
)`
_, err := r.db.ExecContext(
@@ -109,6 +110,7 @@ INSERT INTO ops_system_metrics (
input.Upstream529Count,
input.TokenConsumed,
input.AccountSwitchCount,
opsNullFloat64(input.QPS),
opsNullFloat64(input.TPS),
@@ -177,7 +179,8 @@ SELECT
db_conn_waiting,
goroutine_count,
concurrency_queue_depth
concurrency_queue_depth,
account_switch_count
FROM ops_system_metrics
WHERE window_minutes = $1
AND platform IS NULL
@@ -199,6 +202,7 @@ LIMIT 1`
var dbWaiting sql.NullInt64
var goroutines sql.NullInt64
var queueDepth sql.NullInt64
var accountSwitchCount sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan(
&out.ID,
@@ -217,6 +221,7 @@ LIMIT 1`
&dbWaiting,
&goroutines,
&queueDepth,
&accountSwitchCount,
); err != nil {
return nil, err
}
@@ -273,6 +278,10 @@ LIMIT 1`
v := int(queueDepth.Int64)
out.ConcurrencyQueueDepth = &v
}
if accountSwitchCount.Valid {
v := accountSwitchCount.Int64
out.AccountSwitchCount = &v
}
return &out, nil
}

View File

@@ -56,18 +56,44 @@ error_buckets AS (
AND COALESCE(status_code, 0) >= 400
GROUP BY 1
),
switch_buckets AS (
SELECT ` + errorBucketExpr + ` AS bucket,
COALESCE(SUM(CASE
WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1
ELSE 0
END), 0) AS switch_count
FROM ops_error_logs
CROSS JOIN LATERAL jsonb_array_elements(
COALESCE(NULLIF(upstream_errors, 'null'::jsonb), '[]'::jsonb)
) AS ev
` + errorWhere + `
AND upstream_errors IS NOT NULL
GROUP BY 1
),
combined AS (
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(e.error_count, 0) AS error_count,
COALESCE(u.token_consumed, 0) AS token_consumed
FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
SELECT
bucket,
SUM(success_count) AS success_count,
SUM(error_count) AS error_count,
SUM(token_consumed) AS token_consumed,
SUM(switch_count) AS switch_count
FROM (
SELECT bucket, success_count, 0 AS error_count, token_consumed, 0 AS switch_count
FROM usage_buckets
UNION ALL
SELECT bucket, 0, error_count, 0, 0
FROM error_buckets
UNION ALL
SELECT bucket, 0, 0, 0, switch_count
FROM switch_buckets
) t
GROUP BY bucket
)
SELECT
bucket,
(success_count + error_count) AS request_count,
token_consumed
token_consumed,
switch_count
FROM combined
ORDER BY bucket ASC`
@@ -84,13 +110,18 @@ ORDER BY bucket ASC`
var bucket time.Time
var requests int64
var tokens sql.NullInt64
if err := rows.Scan(&bucket, &requests, &tokens); err != nil {
var switches sql.NullInt64
if err := rows.Scan(&bucket, &requests, &tokens, &switches); err != nil {
return nil, err
}
tokenConsumed := int64(0)
if tokens.Valid {
tokenConsumed = tokens.Int64
}
switchCount := int64(0)
if switches.Valid {
switchCount = switches.Int64
}
denom := float64(bucketSeconds)
if denom <= 0 {
@@ -103,6 +134,7 @@ ORDER BY bucket ASC`
BucketStart: bucket.UTC(),
RequestCount: requests,
TokenConsumed: tokenConsumed,
SwitchCount: switchCount,
QPS: qps,
TPS: tps,
})
@@ -385,6 +417,7 @@ func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []
BucketStart: cursor,
RequestCount: 0,
TokenConsumed: 0,
SwitchCount: 0,
QPS: 0,
TPS: 0,
})

View File

@@ -28,7 +28,6 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
}
return &proxyProbeService{
ipInfoURL: defaultIPInfoURL,
insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate,
validateResolvedIP: validateResolvedIP,
@@ -36,12 +35,20 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
}
const (
defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
defaultProxyProbeTimeout = 30 * time.Second
)
// probeURLs 按优先级排列的探测 URL 列表
// 某些 AI API 专用代理只允许访问特定域名,因此需要多个备选
var probeURLs = []struct {
url string
parser string // "ip-api" or "httpbin"
}{
{"http://ip-api.com/json/?lang=zh-CN", "ip-api"},
{"http://httpbin.org/ip", "httpbin"},
}
type proxyProbeService struct {
ipInfoURL string
insecureSkipVerify bool
allowPrivateHosts bool
validateResolvedIP bool
@@ -60,8 +67,21 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
}
var lastErr error
for _, probe := range probeURLs {
exitInfo, latencyMs, err := s.probeWithURL(ctx, client, probe.url, probe.parser)
if err == nil {
return exitInfo, latencyMs, nil
}
lastErr = err
}
return nil, 0, fmt.Errorf("all probe URLs failed, last error: %w", lastErr)
}
func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Client, url string, parser string) (*service.ProxyExitInfo, int64, error) {
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, 0, fmt.Errorf("failed to create request: %w", err)
}
@@ -78,6 +98,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
}
switch parser {
case "ip-api":
return s.parseIPAPI(body, latencyMs)
case "httpbin":
return s.parseHTTPBin(body, latencyMs)
default:
return nil, latencyMs, fmt.Errorf("unknown parser: %s", parser)
}
}
func (s *proxyProbeService) parseIPAPI(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
var ipInfo struct {
Status string `json:"status"`
Message string `json:"message"`
@@ -89,13 +125,12 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
CountryCode string `json:"countryCode"`
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
}
if err := json.Unmarshal(body, &ipInfo); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
preview := string(body)
if len(preview) > 200 {
preview = preview[:200] + "..."
}
return nil, latencyMs, fmt.Errorf("failed to parse response: %w (body: %s)", err, preview)
}
if strings.ToLower(ipInfo.Status) != "success" {
if ipInfo.Message == "" {
@@ -116,3 +151,19 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
CountryCode: ipInfo.CountryCode,
}, latencyMs, nil
}
func (s *proxyProbeService) parseHTTPBin(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
// httpbin.org/ip 返回格式: {"origin": "1.2.3.4"}
var result struct {
Origin string `json:"origin"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse httpbin response: %w", err)
}
if result.Origin == "" {
return nil, latencyMs, fmt.Errorf("httpbin: no IP found in response")
}
return &service.ProxyExitInfo{
IP: result.Origin,
}, latencyMs, nil
}

View File

@@ -5,6 +5,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/require"
@@ -21,7 +22,6 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{
ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
allowPrivateHosts: true,
}
}
@@ -49,12 +49,16 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
require.ErrorContains(s.T(), err, "failed to create proxy client")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
seen := make(chan string, 1)
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_IPAPI() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
// 检查是否是 ip-api 请求
if strings.Contains(r.RequestURI, "ip-api.com") {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
return
}
// 其他请求返回错误
w.WriteHeader(http.StatusServiceUnavailable)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
@@ -65,45 +69,59 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country)
require.Equal(s.T(), "CC", info.CountryCode)
// Verify proxy received the request
select {
case uri := <-seen:
require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
default:
require.Fail(s.T(), "expected proxy to receive request")
}
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_HTTPBinFallback() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// ip-api 失败
if strings.Contains(r.RequestURI, "ip-api.com") {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
// httpbin 成功
if strings.Contains(r.RequestURI, "httpbin.org") {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"origin": "5.6.7.8"}`)
return
}
w.WriteHeader(http.StatusServiceUnavailable)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.NoError(s.T(), err, "ProbeProxy should fallback to httpbin")
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
require.Equal(s.T(), "5.6.7.8", info.IP)
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_AllFailed() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status: 503")
require.ErrorContains(s.T(), err, "all probe URLs failed")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
if strings.Contains(r.RequestURI, "ip-api.com") {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
return
}
// httpbin 也返回无效响应
if strings.Contains(r.RequestURI, "httpbin.org") {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
return
}
w.WriteHeader(http.StatusServiceUnavailable)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "failed to parse response")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
s.prober.ipInfoURL = "://invalid-url"
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
require.ErrorContains(s.T(), err, "all probe URLs failed")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
@@ -114,6 +132,40 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
require.Error(s.T(), err, "expected error when proxy server is closed")
}
func (s *ProxyProbeServiceSuite) TestParseIPAPI_Success() {
body := []byte(`{"status":"success","query":"1.2.3.4","city":"Beijing","regionName":"Beijing","country":"China","countryCode":"CN"}`)
info, latencyMs, err := s.prober.parseIPAPI(body, 100)
require.NoError(s.T(), err)
require.Equal(s.T(), int64(100), latencyMs)
require.Equal(s.T(), "1.2.3.4", info.IP)
require.Equal(s.T(), "Beijing", info.City)
require.Equal(s.T(), "Beijing", info.Region)
require.Equal(s.T(), "China", info.Country)
require.Equal(s.T(), "CN", info.CountryCode)
}
func (s *ProxyProbeServiceSuite) TestParseIPAPI_Failure() {
body := []byte(`{"status":"fail","message":"rate limited"}`)
_, _, err := s.prober.parseIPAPI(body, 100)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "rate limited")
}
func (s *ProxyProbeServiceSuite) TestParseHTTPBin_Success() {
body := []byte(`{"origin": "9.8.7.6"}`)
info, latencyMs, err := s.prober.parseHTTPBin(body, 50)
require.NoError(s.T(), err)
require.Equal(s.T(), int64(50), latencyMs)
require.Equal(s.T(), "9.8.7.6", info.IP)
}
func (s *ProxyProbeServiceSuite) TestParseHTTPBin_NoIP() {
body := []byte(`{"origin": ""}`)
_, _, err := s.prober.parseHTTPBin(body, 50)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "no IP found")
}
func TestProxyProbeServiceSuite(t *testing.T) {
suite.Run(t, new(ProxyProbeServiceSuite))
}

Some files were not shown because too many files have changed in this diff Show More