Compare commits

..

290 Commits

Author SHA1 Message Date
Wesley Liddick
3bae525026 Merge pull request #650 from wucm667/feat/sync-page-title-on-locale-change
feat(i18n): 切换语言时同步更新页面标题
2026-02-27 19:48:36 +08:00
shaw
df00805a2a feat(frontend): 为管理端用量页面添加列显示设置 2026-02-27 19:41:26 +08:00
Wesley Liddick
a88ee96518 Merge pull request #665 from touwaeriol/fix/2k-image-default-pricing
fix: add 2K image default pricing at 1.5x base price
2026-02-27 19:20:44 +08:00
Wesley Liddick
3cc2f9bd57 Merge pull request #664 from wucm667/fix/account-priority-hint
fix(frontend): add priority hint in edit account modal
2026-02-27 19:19:36 +08:00
erio
d1b684b782 fix: add 2K image default pricing at 1.5x base price
Previously 2K images used the same base price as 1K ($0.134).
Now 2K uses 1.5x multiplier ($0.201), consistent with 4K using 2x ($0.268).

- Backend: add 2K size branch in getDefaultImagePrice
- Frontend: update 2K placeholder from 0.134 to 0.201
- Tests: update assertions for new 2K default price
2026-02-27 17:37:30 +08:00
wucm667
6460d4ad3a fix(frontend): add priority hint in edit account modal 2026-02-27 16:00:11 +08:00
Wesley Liddick
19ea392d5d Merge pull request #663 from touwaeriol/fix/update-antigravity-useragent-version
fix: update antigravity user-agent version to 1.19.6
2026-02-27 15:28:45 +08:00
Wesley Liddick
fb4d016176 Merge pull request #659 from touwaeriol/feature/gemini-3.1-flash-image
feat: 新增 gemini-3.1-flash-image 支持,替代 gemini-3-pro-image
2026-02-27 15:28:33 +08:00
erio
afec747d9e fix: update antigravity user-agent version to 1.19.6
Update the default user-agent version from 1.18.4 to 1.19.6
to match the latest official antigravity client.
2026-02-27 12:31:51 +08:00
erio
7388fcce41 fix: gofmt alignment in constants.go 2026-02-27 09:52:50 +08:00
erio
a6f9f9f968 feat: replace gemini-3-pro-image with gemini-3.1-flash-image
- Add migration 060 to update model_mapping for all antigravity accounts
- Remove gemini-3-pro-image and gemini-3-pro-image-preview mappings
- Add gemini-3.1-flash-image and gemini-3.1-flash-image-preview mappings
- Update frontend usage window to show GImage for new model
- Update isImageGenerationModel to support new model
2026-02-27 09:52:50 +08:00
Wesley Liddick
29759721e0 Merge pull request #651 from cagedbird043/pr/bulk-edit-platform-filter
fix(frontend): 批量编辑添加跨平台模型映射警告与智能过滤
2026-02-27 09:03:00 +08:00
Wesley Liddick
1941b20521 Merge pull request #657 from alfadb/fix/count-tokens-404-passthrough
fix(gateway): count_tokens 不支持时返回 404 而非伪造的 200
2026-02-27 08:42:46 +08:00
alfadb
e6969acb50 fix: address review - fix log wording and add response body assertion in test 2026-02-26 23:49:30 +08:00
alfadb
9489531431 fix(gateway): return 404 instead of fake 200 for unsupported count_tokens endpoint
PR #635 returned HTTP 200 with {"input_tokens": 0} when upstream doesn't
support count_tokens (404). This caused Claude Code CLI to trust the zero
value, believing context uses 0 tokens, so auto-compression never triggers.

Fix: return 404 with proper error body so CLI falls back to its local
tokenizer for accurate estimation. Return nil (not error) to avoid
polluting ops error metrics with expected 404s.

Affected paths:
- Passthrough APIKey accounts: upstream 404 now passed through as 404
- Antigravity accounts: same fix (was also returning fake 200)
2026-02-26 23:34:53 +08:00
cagedbird043
32b7c0ca9b feat(frontend): 补齐 GPT-5.3 系列模型到白名单、批量编辑列表与预设映射
- useModelWhitelist.ts 添加 gpt-5.3-codex、gpt-5.3-codex-spark
- BulkEditAccountModal.vue 添加 5.3 模型选项与预设按钮(含 5.2→5.3 升级映射)
2026-02-26 16:04:15 +08:00
shaw
4ac57b4edf fix: 临时移除fast-mode-2026-02-01避免429问题 2026-02-26 15:44:28 +08:00
cagedbird043
685a1e0ba3 feat(i18n): 添加批量编辑跨平台警告的中英文翻译 2026-02-26 15:24:50 +08:00
cagedbird043
e350aab1bd fix(frontend): 批量编辑添加跨平台模型映射警告与过滤
- 新增 selectedPlatforms prop,从父组件传入选中账号的平台集合
- 根据选中平台过滤模型列表与预设映射按钮,避免误操作
- 混选多平台时显示 amber 警告横幅,提醒用户注意映射适用性
- 仅警告不阻断,保持加法兼容
2026-02-26 15:24:50 +08:00
Wesley Liddick
0dd6986e28 Merge pull request #639 from cagedbird043/pr/refactor-antigravity-model-source
refactor(admin): 消除测试连接 Gemini 模型硬编码,统一由 DefaultModels 提供
2026-02-26 14:57:13 +08:00
Wesley Liddick
6d0102a70c Merge pull request #638 from cagedbird043/pr/antigravity-claude-model-cleanup
feat(antigravity): 更新 opencode.json 模板至 Claude 4.6 并补齐模型支持
2026-02-26 14:55:32 +08:00
cagedbird043
f96a2a18c1 feat(opencode): 更新 opencode.json 模板至 Claude 4.6(默认启用 thinking) 2026-02-26 14:27:51 +08:00
cagedbird043
f955b04a6f feat(frontend): 补齐 Antigravity Claude 4.6 前端预设映射与显示 2026-02-26 14:27:51 +08:00
cagedbird043
2fd6ac319b feat(antigravity): 添加 Claude Opus/Sonnet 4.6 后端模型定义 2026-02-26 14:27:51 +08:00
wucm667
82fbf452a8 feat(i18n): 切换语言时同步更新页面标题
- resolveDocumentTitle() 新增 titleKey 参数,优先通过 i18n 翻译
- router beforeEach 中将路由 meta.titleKey 传入标题解析函数
- setLocale() 切换语言后同步刷新 document.title
2026-02-26 14:04:13 +08:00
cagedbird043
ba69736f55 refactor(admin): 测试连接模型列表改为复用 antigravity.DefaultModels,消除硬编码重复 2026-02-26 13:34:10 +08:00
shaw
c75c6b6858 fix: 将 DriveClient 注入 GeminiOAuthService,消除单元测试中的真实 HTTP 调用
FetchGoogleOneTier 原先在方法内部直接创建 DriveClient 实例,
导致单元测试中对 googleapis.com 发起真实 HTTP 请求,在 CI 环境
产生 401 错误。

将 DriveClient 作为依赖注入到 GeminiOAuthService,遵循项目
端口与适配器架构规范:
- 新增 repository/gemini_drive_client.go 作为 Provider
- 注册到 repository Wire ProviderSet
- 测试中使用 mockDriveClient 替代真实调用
2026-02-26 10:53:04 +08:00
Wesley Liddick
de61745bb2 Merge pull request #635 from alfadb/fix/count-tokens-fallback-for-proxy
fix: count_tokens 端点不支持时降级返回空值
2026-02-26 10:07:30 +08:00
Wesley Liddick
3fab0fcd4c Merge pull request #644 from LemonZuo/fix/remove-pgdata-env-var
移除 PostgreSQL 容器多余重复的 PGDATA 环境变量
2026-02-26 09:40:50 +08:00
alfadb
03bcd94ae5 fix: count_tokens 端点不支持时降级返回空值 (404 only)
第三方 Anthropic 中转站通常不支持 /v1/messages/count_tokens 端点,
上游返回 404 时降级返回 {input_tokens: 0},客户端 fallback 到本地估算。

- 仅匹配 404 状态码,语义明确:端点不存在
- 其他错误 (400/429/500) 保留原始处理链和 ops 遥测
- 无需解析错误消息内容,不依赖字符串匹配
- 新增 table-driven 测试覆盖 fallback 和 non-fallback 路径
2026-02-26 09:28:45 +08:00
Lemon
0343bc7777 fix: 移除 PostgreSQL 容器多余重複的 PGDATA 环境变量 2026-02-26 09:01:03 +08:00
Wesley Liddick
565d19acfd Merge pull request #636 from cagedbird043/pr/antigravity-gemini-3.1-models
fix(antigravity): 补全测试连接 Gemini 模型列表并新增 3.1 Pro High/Low 支持
2026-02-26 08:52:19 +08:00
Wesley Liddick
960acf1982 Merge pull request #632 from cagedbird043/pr/gemini-v1beta-template-align
feat: 对齐 Gemini v1beta 模型模板与映射顺序
2026-02-26 08:45:56 +08:00
cagedbird043
ece911521e fix(antigravity): 修正 Gemini 3.1 Pro High/Low 发布日期为 2026-02-19 2026-02-25 20:18:19 +08:00
cagedbird043
5d95e59742 fix(admin): 补全 antigravity 测试连接下拉框的 Gemini 模型列表 2026-02-25 18:51:47 +08:00
cagedbird043
01d084bbfd feat(antigravity): 新增 Gemini 3.1 Pro High 和 Gemini 3.1 Pro Low 模型支持 2026-02-25 18:51:47 +08:00
cagedbird043
7918fc2844 feat: 对齐 Gemini v1beta 模板模型与顺序 2026-02-25 15:19:23 +08:00
Wesley Liddick
31b30a6df2 Merge pull request #631 from cagedbird043/pr/opencode-template-openai-antigravity
feat: 补齐 OpenCode 模板中的 OpenAI 与 Antigravity 模型配置
2026-02-25 14:23:43 +08:00
cagedbird043
d217b59e0b feat: 补齐 Antigravity OpenCode 模板模型配置 2026-02-25 14:17:31 +08:00
cagedbird043
169a4b9d32 feat: 补齐 OpenAI OpenCode 模板模型配置 2026-02-25 14:17:31 +08:00
shaw
15f3ffb165 chore: 调整模型定价文件仓库 2026-02-25 13:50:21 +08:00
Wesley Liddick
02db1010dd Merge pull request #630 from DouDOU-start/main
Sora 平台: SDK 重构、JSON 转义修复及 AT 手动导入
2026-02-25 11:49:48 +08:00
huangenjun
935ea66681 fix: 修复 sora_sdk_client 类型断言未检查的 errcheck lint 错误
使用安全的 comma-ok 模式替代裸类型断言,避免 golangci-lint errcheck 报错。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 11:43:08 +08:00
huangenjun
26060e702f feat: Sora 平台支持手动导入 Access Token
新增 Access Token 输入方式,支持批量粘贴(每行一个)直接创建账号,
无需经过 OAuth 授权流程。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 11:33:07 +08:00
huangenjun
65d4ca2563 fix: 修复流式响应中 URL 的 & 被转义为 \u0026 的问题
新增 jsonMarshalRaw 使用 SetEscapeHTML(false) 替代 json.Marshal,
避免 HTML 字符转义导致客户端无法直接使用返回的 URL。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 11:32:56 +08:00
huangenjun
3c619a8da5 refactor: 使用 go-sora2api SDK 替代自建 Sora 客户端
使用 go-sora2api v1.1.0 SDK 替代原有 ~2000 行自建 HTTP/PoW/TLS 指纹代码,
SDK 提供高并发性能优化(实例级 rand、PoW 缓冲区复用、context.Context 支持)。

- 新增 SoraSDKClient 适配器实现 SoraClient 接口
- 精简 sora_client.go 为仅保留接口和类型定义
- 更新 Wire 绑定使用 SoraSDKClient
- 删除 SoraDirectClient、sora_curl_cffi_sidecar、sora_request_guard 等旧代码

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 10:15:38 +08:00
shaw
ded9b6c14e fix: upgrade utls to v1.8.2 to resolve GO-2026-4512 vulnerability 2026-02-25 08:57:43 +08:00
Wesley Liddick
609abbbd7c Merge pull request #624 from cagedbird043/pr/antigravity-gemini31-passthrough-buttons
feat: 补充 Antigravity 的 Gemini 3.1 Pro 透传快捷按钮
2026-02-25 08:45:49 +08:00
Wesley Liddick
1b4e504fad Merge pull request #625 from cagedbird043/pr/antigravity-default-gemini31-passthrough
fix: 默认补全 Antigravity 的 Gemini 3.1 Pro 透传映射
2026-02-25 08:45:16 +08:00
Wesley Liddick
0a3a445828 Merge pull request #628 from cagedbird043/pr/docs-model-mapping-bulk-edit-tip
docs: 增加跨平台批量修改导致模型映射丢失的排障经验
2026-02-25 08:31:31 +08:00
Wesley Liddick
c7e18bd5be Merge pull request #627 from touwaeriol/pr/bugfixes-and-enhancements
feat: 反重力(Antigravity)增强、Failover 重构及新模型支持
2026-02-25 08:30:25 +08:00
cagedbird043
083d202fe4 docs: 增加跨平台批量修改导致模型映射丢失的排障经验 2026-02-25 01:02:25 +08:00
erio
8365a8328b merge: resolve conflicts with upstream/main (Gemini 3→3.1 mappings) 2026-02-25 00:38:39 +08:00
erio
58f21e4b3a fix: correct gofmt alignment in gemini-3.1-pro fallback pricing 2026-02-25 00:23:37 +08:00
erio
5bd7408b2f fix: add fallback pricing for opus-4.6 and gemini-3.1-pro models 2026-02-25 00:10:07 +08:00
erio
c671e8dd1d fix: 统一gemini-3默认映射为非强制3.1 2026-02-24 23:24:48 +08:00
cagedbird043
a3aed3c4c3 fix: 默认补全 antigravity 的 Gemini 3.1 Pro 透传映射 2026-02-24 22:54:11 +08:00
cagedbird043
c008649584 feat: 补充 antigravity 的 Gemini 3.1 Pro 透传快捷按钮 2026-02-24 22:53:53 +08:00
Wesley Liddick
516f8f287c Merge pull request #623 from cagedbird043/fix/antigravity-mapping-upgrade-additions
fix: 补全 Antigravity 模型映射升级与快捷按钮
2026-02-24 22:50:24 +08:00
Wesley Liddick
66148690c6 Merge pull request #622 from cagedbird043/fix/auto-clear-account-error-on-usage
fix: 刷新用量成功后自动清理账号可恢复错误状态
2026-02-24 22:49:08 +08:00
Wesley Liddick
cadd7f546f Merge pull request #621 from cagedbird043/fix/gemini-auth-url-613
fix: 修复 Gemini 授权链接生成失败(issue #613)
2026-02-24 22:48:09 +08:00
erio
a3ff317f1c feat: optimize model rate limit indicator layout with short aliases
- Change layout from fixed 3-column grid to vertical-first responsive
  columns (1 col for ≤4 items, 2 cols for ≤8, 3 cols for 9+)
- Add short aliases for all known model scope keys (e.g. COpus46, CSon46,
  G3PH, G3F) to reduce badge width
- Display countdown timer directly on each badge (supports h/m/s)
- Retain legacy scope aliases for backward compatibility
2026-02-24 22:11:50 +08:00
erio
d8d4b0c0c7 fix: enable Gemini model_mapping UI and extend warmup to Antigravity
- Remove Gemini platform exclusion from model restriction UI in
  Create/Edit account modals (Gemini now supports model_mapping)
- Remove outdated Gemini model passthrough info cards
- Add model_mapping field to GeminiCredentials type
- Extend warmup request interception toggle to Antigravity platform
- Remove redundant try/catch in API key account creation
- Remove noisy gateway.request_completed debug log
- Reorganize Gemini model mapping sections in constants.go
2026-02-24 21:30:32 +08:00
erio
d616f8c854 refactor: remove unused ClientSecret constant
The ClientSecret constant was left as an empty string after
getClientSecret() was refactored to use defaultClientSecret.
Remove the dead constant and update the test accordingly.
2026-02-24 21:09:46 +08:00
erio
b6fa8b8eec fix: update tests for defaultClientSecret and align migration 058
- Fix oauth_test.go and client_test.go to use defaultClientSecret
  variable instead of env var (init() already sets the default)
- Align migration 058 gemini-3-pro-high/low/preview mappings with
  constants.go (map to 3.1 versions)
2026-02-24 21:06:10 +08:00
erio
36d2e6999b feat: add default value for Antigravity OAuth client secret
Add a built-in default for ANTIGRAVITY_OAUTH_CLIENT_SECRET so the
service works out of the box without requiring environment variable
configuration. The env var can still override the default.
2026-02-24 20:54:28 +08:00
cagedbird043
076c00063d feat: 补全 antigravity 模型映射快捷按钮 2026-02-24 20:31:36 +08:00
cagedbird043
ea8104c6a2 fix: antigravity 默认补全 gemini-3-flash 透传 2026-02-24 20:31:36 +08:00
erio
ca3e9336e1 test: update UserAgent version assertion to match 1.18.4 default 2026-02-24 20:31:02 +08:00
erio
f92ab48166 fix: add gemini-3.1-pro-preview to default Antigravity model mapping
Add missing gemini-3.1-pro-preview -> gemini-3.1-pro-high mapping to
DefaultAntigravityModelMapping for consistency with migration 059.
2026-02-24 20:06:19 +08:00
cagedbird043
c10267ce2b fix: 刷新用量成功后自动清理账号可恢复错误状态 2026-02-24 20:04:36 +08:00
cagedbird043
9bd6a62ab3 test: 更新 Gemini OAuth 内置回退测试用例 2026-02-24 20:04:05 +08:00
cagedbird043
0dbea6ca58 fix: 修复 Gemini 授权链接生成失败并改进错误提示 2026-02-24 20:04:05 +08:00
erio
6523b23221 revert: remove backend-ci.yml changes (fork-specific CI config) 2026-02-24 19:45:23 +08:00
erio
29c406dda0 feat: add migrations for sonnet-4-6 and gemini-3.1-pro model mappings
Add migration 058 to update existing Antigravity accounts with
claude-sonnet-4-6 in model_mapping. Add migration 059 to add
gemini-3.1-pro-high/low/preview mappings.
2026-02-24 19:40:30 +08:00
erio
483c8f246d chore: update default Antigravity UserAgent version to 1.18.4
Update the default ANTIGRAVITY_USER_AGENT_VERSION from 1.84.2 to
1.18.4 to match the current Antigravity-Manager desktop client.
2026-02-24 19:39:15 +08:00
erio
645f283108 feat: add claude-sonnet-4-6 and gemini-3.1-pro model support
Add claude-sonnet-4-6 to identity injection modelInfoMap and
Antigravity model selector. Add gemini-3.1-pro-high/low to
Antigravity model list and Sonnet 4.6 preset mapping.
2026-02-24 19:30:01 +08:00
erio
da6fd45000 chore: add sonnet-4-6 mapping, config defaults, and CI improvements
- Add claude-sonnet-4-6 to default Antigravity model mapping
- Add antigravity_extra_retries default value in config
- Add cache-dependency-path to CI setup-go for faster builds
- Simplify vitest config to avoid vite plugin compatibility issues
2026-02-24 18:55:39 +08:00
erio
fb3ef5f388 fix(frontend): add Gemini models to bulk edit and fix status grid layout
Add Gemini model presets to BulkEditAccountModal for bulk model mapping.
Fix AccountStatusIndicator model rate limit grid layout using proper
grid container.
2026-02-24 18:55:25 +08:00
erio
86bc76e352 test: add warmup request interception unit tests
Add comprehensive tests for warmup request interception behavior
covering Antigravity accounts with various credential configurations.
2026-02-24 18:55:11 +08:00
erio
644058174e fix(gemini): enable model_mapping filtering for Gemini API Key accounts
Remove the special case that bypassed model-supported checks for Gemini
API Key accounts, allowing model_mapping to filter requests properly.
Add tests for multiplatform model filtering behavior.
2026-02-24 18:54:59 +08:00
erio
4573868c08 fix(antigravity): bill with mapped model and use final model key for rate limiting
- Use mapped model (billingModel) instead of original request model for billing
- Use resolveFinalAntigravityModelKey for 429 rate limit model key,
  ensuring rate limit records match the actual upstream model
- Add regression tests for both fixes
2026-02-24 18:08:19 +08:00
erio
09166a52f8 refactor: extract failover error handling into FailoverState
- Extract duplicated failover logic from gateway_handler.go (3 places)
  and gemini_v1beta_handler.go into shared failover_loop.go
- Introduce FailoverState with HandleFailoverError and HandleSelectionExhausted
- Move helper functions (needForceCacheBilling, sleepWithContext) into failover_loop.go
- Add comprehensive unit tests (32+ test cases)
- Delete redundant gateway_handler_single_account_retry_test.go
2026-02-24 18:08:04 +08:00
erio
aaac1aaca9 feat: add mixed-channel precheck API for account-group binding
Add a dedicated CheckMixedChannel endpoint that allows the frontend
to pre-validate mixed channel risk before submitting create/update
requests. This improves UX by showing warnings earlier in the flow
instead of only after form submission.

Backend changes:
- Add CheckMixedChannelRequest struct and CheckMixedChannel handler
- Register POST /check-mixed-channel route
- Expose CheckMixedChannelRisk as public method on AdminService
- Simplify Create/Update 409 responses (remove details/require_confirmation)
- Add comprehensive handler tests and stub methods

Frontend changes:
- Add checkMixedChannelRisk API function and TypeScript types
- Refactor CreateAccountModal to precheck before step transition and submission
- Refactor EditAccountModal to precheck before update submission
- Replace pendingPayload pattern with action-based dialog flow
2026-02-24 17:16:53 +08:00
erio
59898c16c6 fix: fix intercept_warmup_requests config not being saved
Extract applyInterceptWarmup utility to unify all credential building
call sites:
- Fix upstream account creation missing intercept_warmup_requests write
- Fix apikey edit mode missing else-branch to clear the setting
- Add backend unit test for IsInterceptWarmupEnabled
- Add frontend unit test for credentialsBuilder
2026-02-24 16:48:16 +08:00
erio
0dacdf480b fix: distinguish client disconnection from upstream retry failure
Before this change, when a client disconnected mid-request, the error
message was "Upstream request failed after retries", which is misleading
and pollutes error logs. Now we check context.Err() to return a more
accurate "Client disconnected" message for both Claude and Gemini
forward paths.
2026-02-24 16:45:08 +08:00
erio
fdf9f68298 fix: update Claude usage window to support 4.6 models
The usage progress bar only matched claude-sonnet-4-5 and
claude-opus-4-5-thinking. After upgrading to 4.6, the backend returns
claude-sonnet-4-6/claude-opus-4-6-thinking which didn't match,
causing the Claude usage bar to not display.

- Add claude-sonnet-4-6 and claude-opus-4-6-thinking to the match list
- Rename label from "C4.5" to "Claude" for future-proofing
2026-02-24 16:44:18 +08:00
shaw
7be5e1734c fix: 修复 CI 集成测试因 context deadline exceeded 未被跳过而失败
skipIfExternalServiceUnavailable 检查了 "timeout" 但 Go 的
context.DeadlineExceeded 错误信息是 "context deadline exceeded",
不包含 "timeout" 子串,导致外部服务不可达时测试直接失败而非跳过。
2026-02-24 15:04:04 +08:00
shaw
bfe414670f chore: update version 2026-02-24 14:51:10 +08:00
shaw
e435a46db5 fix: 修复 antigravity UserAgent 重构遗留的编译错误和测试不匹配
- oauth.go: GetUserAgent() 缺少闭合大括号导致语法错误
- client_test.go/oauth_test.go: UserAgent 变量已重构为 GetUserAgent(),更新测试引用
- model_rate_limit_test.go: gemini-3-pro-preview 映射目标已更新为 gemini-3.1-pro-high,同步测试
2026-02-24 14:44:57 +08:00
Wesley Liddick
84bd881e68 Merge pull request #608 from cagedbird043/feature/gemini-3-to-3.1-mapping
feat: Antigravity 将 gemini-3-pro 路由升级到 gemini-3.1-pro
2026-02-24 14:08:42 +08:00
Wesley Liddick
a901117b8c Merge pull request #605 from cagedbird043/feature/antigravity-user-agent-configurable
feat: 让 Antigravity User-Agent 版本可通过环境变量配置
2026-02-24 14:08:25 +08:00
Wesley Liddick
6bccb8a8a6 Merge branch 'main' into feature/antigravity-user-agent-configurable 2026-02-24 14:01:43 +08:00
Wesley Liddick
3de1e0e485 Merge pull request #597 from 0-don/feat/add-gemini-3.1-pro-preview
feat: add gemini-3.1-pro-preview to model lists
2026-02-24 12:25:17 +08:00
shaw
492b852a1f fix: 幂等测试使用哈希值避免超出 VARCHAR(64) 限制
idempotency_key_hash 和 request_fingerprint 列为 VARCHAR(64),
而 uniqueTestValue 生成的字符串含完整测试名可能超过 64 字符。
新增 hashedTestValue 辅助函数对测试值做 SHA-256 哈希,
与生产逻辑一致且严格符合列宽限制。
2026-02-24 12:18:07 +08:00
shaw
8a137405d4 fix: 移除重复的 ptrTime 函数声明修复编译错误
idempotency_repo_integration_test.go 中的 ptrTime 与
scheduler_cache.go 中的声明冲突,导致 repository 包测试构建失败。
2026-02-24 11:49:01 +08:00
shaw
f431f5ed72 fix: 恢复backend-ci.yml 2026-02-24 11:37:14 +08:00
shaw
980fc9608f fix: 修复日志重复输出及清理冗余迁移逻辑
- logger: sinkCore 包装 tee core 时绕过了子 core 的 Check 级别过滤,
  导致每条日志同时写入 stdout 和 stderr,表现为启动日志重复显示。
  修复为正确委托 Check 给内部 tee core,sinkCore.Write 仅负责 sink 转发。
- migration 054: 移除冗余的遗留列回填逻辑,migration 009 已完成数据迁移,
  直接删除遗留列即可。
2026-02-24 11:31:19 +08:00
Wesley Liddick
07be258dca Merge pull request #603 from mt21625457/release
feat : 大幅度的性能优化 和 新增了很多功能
2026-02-24 11:08:47 +08:00
Wesley Liddick
dbdb29594c Merge pull request #606 from Nek0Neko/fix/user-balance-modal-dark-mode
fix: 更新余额显示样式以支持深色模式
2026-02-24 10:35:52 +08:00
yangjianbo
53d55bb92f feat: 工程清理 2026-02-24 10:19:58 +08:00
cagedbird043
3f3efff065 feat: 添加 Gemini 3→3.1 前端快捷按钮及后端映射(包括 3.1 透传) 2026-02-23 23:06:07 +08:00
Leon-mac
57b078f2c7 fix: 更新余额显示样式以支持深色模式 2026-02-23 22:27:09 +08:00
cagedbird043
1fc6ef3d4f feat: 让 User-Agent 版本可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.84.2 2026-02-23 21:17:35 +08:00
yangjianbo
c2567831d9 fix(service): 使用 os.Root 修复 Sora 存储路径告警
- 将媒体写入和删除切换为 os.Root 沙箱 API
- 移除旧的路径拼接校验辅助函数并收敛删除逻辑
- 调整并新增相关单元测试覆盖删除行为

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 16:06:04 +08:00
yangjianbo
e8671fd7c2 fix(service): 修复 Sora 媒体落地路径穿越风险
- 新增安全路径拼接校验,确保目标文件仍在下载目录内
- 清理失败下载文件时复用安全校验,避免不安全删除路径
- 增加扩展名白名单归一化与相关单元测试

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 14:42:07 +08:00
yangjianbo
4950ee48a0 chore(version): 更新版本号至 0.1.83.4 2026-02-23 13:07:34 +08:00
yangjianbo
5fa45f3b8c feat(idempotency): 为关键写接口接入幂等并完善并发容错 2026-02-23 12:45:37 +08:00
yangjianbo
3b6584cc8d chore(version): 更新版本号至 0.1.83.3 2026-02-22 22:20:42 +08:00
yangjianbo
7be1195281 feat(api-key): 增加 API Key 上次使用时间并补齐测试 2026-02-22 22:07:17 +08:00
yangjianbo
1fae8d086d fix(codex): 补回窗口绝对重置时间类型定义 2026-02-22 21:06:22 +08:00
yangjianbo
10636d8a1f fix(codex): 修复额度窗口过期展示并补齐高覆盖测试
- 后端新增绝对重置时间字段计算(codex_5h_reset_at/codex_7d_reset_at)

- 前端统一窗口解析逻辑:绝对时间优先,updated_at+seconds 回退,过期自动归零

- 新增后端与前端单元测试,覆盖关键边界与异常场景
2026-02-22 21:04:52 +08:00
yangjianbo
c67f02eaf0 fix(jwt): 修复仅配置小时时会话提前失效问题
- 将 jwt.access_token_expire_minutes 默认值改为 0,未显式配置时回退 expire_hour

- 调整配置校验为允许 0,仅拒绝负数并补充优先级注释

- 新增配置与认证服务单元测试,覆盖分钟优先与小时回退场景

- 更新示例配置文档,明确分钟/小时优先级与默认行为
2026-02-22 17:37:35 +08:00
yangjianbo
0b32f61062 fix(ratelimit): 清除限流时同步清理临时不可调度状态
- ClearRateLimit 增加清理 temp_unschedulable 与缓存\n- 新增 ClearRateLimit 相关单元测试覆盖成功与失败分支
2026-02-22 17:00:29 +08:00
yangjianbo
2ee6c26676 fix(gateway): 修复粘性会话预取分组错配并优化并发等待热路径 2026-02-22 16:43:33 +08:00
yangjianbo
a89477ddf5 perf(gateway): 优化热点路径并补齐高覆盖测试 2026-02-22 13:31:30 +08:00
yangjianbo
2f520c8d47 docs(readme): 说明 Sora 功能暂不可用及技术问题
- 在 README.md 增加 Sora 暂不可用状态说明
- 在 README_CN.md 增加对应中文说明并标注恢复后可选
- 明确 gateway.sora_* 配置当前仅为预留项

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-22 13:06:56 +08:00
yangjianbo
33db7a0fb6 feat(gateway): 引入使用量记录有界 worker 池与自动扩缩容
- 新增 UsageRecordWorkerPool,支持有界队列、溢出降级策略与自动扩缩容
- 将 Gateway/OpenAI/Sora/Gemini 使用量记录改为提交到统一任务池执行
- 增加 usage_record 配置默认值与校验规则,并补充配置与任务提交相关测试
- 注入并托管 worker 池生命周期,服务退出时统一 StopAndWait

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-22 12:56:57 +08:00
yangjianbo
50b9897182 docs(perf): 新增后端热点API性能优化审计行动计划 2026-02-22 12:54:21 +08:00
yangjianbo
f8ac5538e2 Merge branch 'test' into release 2026-02-21 22:00:16 +08:00
yangjianbo
1985be26b2 fix(gateway): 恢复 Anthropic 透传流数据间隔超时保护并补充回归测试 2026-02-21 16:54:44 +08:00
yangjianbo
fdfc739b72 fix(anthropic): 补齐创建账号页自动透传开关并验证后端透传参数
- 在 CreateAccountModal 为 Anthropic API Key 增加自动透传开关

- 创建请求写入 extra.anthropic_passthrough 并补充状态重置

- 新增 AccountHandler 单测,验证 extra 字段从请求到 CreateAccountInput 的透传
2026-02-21 14:40:31 +08:00
yangjianbo
bde9dbc57a feat(anthropic): 支持 API Key 自动透传并优化透传链路性能
- 新增 Anthropic API Key 自动透传开关与后端透传分支(仅替换认证)

- 账号编辑页新增自动透传开关,默认关闭

- 优化透传性能:SSE usage 解析 gjson 快路径、减少请求体重复拷贝、优化流式写回与非流式 usage 解析

- 补充单元测试与 benchmark,确保 Claude OAuth 路径不受影响
2026-02-21 14:16:18 +08:00
yangjianbo
80510e5f16 fix(gateway): 明确旧协议接口不支持的错误提示
将 /v1/chat/completions 的拦截文案改为旧协议不支持,避免误导为会路由到 Sora。
明确要求客户端改用 /v1/responses。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-21 12:36:43 +08:00
yangjianbo
773f20ed5e Merge branch 'test' into release 2026-02-21 12:15:30 +08:00
yangjianbo
f323174d07 fix(openai): 修复 codex_cli_only 误拦截并补充 codex 家族识别
- 为 codex_cli_only 增加 originator 判定通道,避免仅依赖 User-Agent 误拦截
- 扩展官方客户端家族标识,补充 codex_chatgpt_desktop 等常见前缀
- 新增并更新单元测试与网关透传回归测试,覆盖 UA 与 originator 组合场景

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-21 12:06:24 +08:00
yangjianbo
987589eabc Merge branch 'test' into release 2026-02-21 10:07:53 +08:00
0-don
1004bd86ac feat: add gemini-3.1-pro-preview to model lists
Add the newly released Gemini 3.1 Pro model to both the
native API fallback list and the admin UI test model dropdown.
2026-02-20 23:27:30 +01:00
yangjianbo
03f69dd394 fix(proxy): 将401/405质量检测结果调整为告警 2026-02-20 14:42:07 +08:00
yangjianbo
d14c24bbf3 feat(proxy): 持久化质量检测结果并在列表展示 2026-02-20 12:13:04 +08:00
yangjianbo
48dc011b2a test(admin,service): 修复代理质量与计费单测口径 2026-02-19 21:39:31 +08:00
yangjianbo
b341810e60 fix(sora): 优化 challenge 重试与调试日志 2026-02-19 21:38:04 +08:00
yangjianbo
46d9aee6dd feat(proxy,sora): 增强代理质量检测与Sora稳定性并修复审查问题 2026-02-19 21:18:35 +08:00
yangjianbo
36a1a7998b feat(sora): 强制Sora走curl_cffi sidecar并完善校验测试 2026-02-19 20:29:31 +08:00
yangjianbo
40498aac9d feat(sora): 对齐sora2api分镜角色去水印与挑战错误治理 2026-02-19 20:04:10 +08:00
yangjianbo
440b87094a fix(sora): 增强 Cloudflare 挑战识别并收敛 Sora 请求链路
- 在 failover 场景透传上游响应头并识别 Cloudflare challenge/cf-ray

- 统一 Sora 任务请求的 UA 与代理使用,sentinel 与业务请求保持一致

- 修复流式错误事件 JSON 转义问题并补充相关单元测试
2026-02-19 15:09:58 +08:00
yangjianbo
0832dfb32e fix(sora): 默认开启 TLS 指纹并支持显式关闭 2026-02-19 08:30:54 +08:00
yangjianbo
be09188bda feat(account-test): 增强 Sora 账号测试能力探测与弹窗交互
- 后端新增 Sora2 邀请码与剩余额度探测,并补充对应结果解析
- Sora 测试流程补齐请求头与 Cloudflare 场景提示,完善单测覆盖
- 前端测试弹窗对 Sora 账号改为免选模型流程,并新增中英文提示文案

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 08:29:51 +08:00
yangjianbo
5d2219d299 fix(sora): 修复令牌刷新请求格式与流式错误转义
- 将 refresh_token 恢复请求改为表单编码并匹配 OAuth 约定
- 流式错误改为 JSON 序列化,避免消息含引号或换行导致 SSE 非法
- 补充 Sora token 恢复与 failover 流式错误透传回归测试

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 08:23:00 +08:00
yangjianbo
900cce20a1 feat(sora): 对齐 Sora OAuth 流程并隔离网关请求路径
- 新增并接通 Sora 专用 OAuth 接口与 ST/RT 换取能力
- 完成前端 Sora 授权、RT/ST 手动导入与账号创建流程
- 强化 Sora token 恢复、转发日志与网关路由隔离行为
- 补充后端服务层与路由层相关测试覆盖

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 08:02:56 +08:00
yangjianbo
36bb327024 fix: 更新 ListWithFilters 方法以支持 groupID 参数 2026-02-18 20:52:35 +08:00
yangjianbo
5d9667d27a Merge branch 'main' into test
# Conflicts:
#	backend/cmd/server/VERSION
#	backend/ent/migrate/schema.go
#	backend/ent/mutation.go
#	backend/ent/runtime/runtime.go
#	backend/ent/usagelog.go
#	backend/ent/usagelog/usagelog.go
#	backend/ent/usagelog/where.go
#	backend/ent/usagelog_create.go
#	backend/ent/usagelog_update.go
#	backend/internal/repository/usage_log_repo.go
#	backend/internal/server/api_contract_test.go
#	backend/internal/server/middleware/cors.go
#	backend/internal/service/gateway_service.go
2026-02-18 20:16:31 +08:00
yangjianbo
fad04ca995 Merge branch 'main' of https://github.com/mt21625457/aicodex2api 2026-02-18 20:10:32 +08:00
shaw
074bd0dfda fix: 临时移除context-1m-2025-08-07以确保避免sonnet1m触发429 2026-02-18 18:41:30 +08:00
shaw
b41fa5e15f feat: 前端新增sonnet4.6快捷映射按钮 2026-02-18 17:06:37 +08:00
Wesley Liddick
beceb45d23 Merge pull request #591 from miraserver/main
feat: add Cache TTL Override per account
2026-02-18 15:59:25 +08:00
Wesley Liddick
9450edf462 Merge pull request #589 from 0-don/fix/strip-unsupported-codex-params
fix: strip unsupported parameters from Codex model requests
2026-02-18 15:58:05 +08:00
Wesley Liddick
785a7397f8 Merge pull request #579 from KortanZ/main
fix: accept openai x-stainless-* header to fix CORS error
2026-02-18 15:57:44 +08:00
John Doe
3d1f03c286 feat: add Cache TTL Override per account + bump VERSION to 0.1.83
- Account-level cache TTL override: rewrite Anthropic cache_creation
  token classification (5m↔1h) in streaming/non-streaming responses
- New DB field cache_ttl_overridden in usage_log for billing tracking
- Migration 055_add_cache_ttl_overridden
- Frontend: CacheTTL override toggle in account create/edit modals
- Ent schema regenerated for new usage_log fields

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 14:19:24 +03:00
0-don
8ff40f52e0 fix: remove unsupported parameters from Codex model requests 2026-02-17 00:06:32 +01:00
yangjianbo
6577f2ef03 fix(gateway): 避免SSE delta将缓存创建明细重置为0
- 仅在 delta 中 5m/1h 值大于0时覆盖 usage 明细
- 新增回归测试覆盖 delta 默认 0 不应覆盖 message_start 非零值
- 迁移 054 在删除 legacy 字段前追加一次回填,避免升级实例丢失历史写入

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 13:23:12 +08:00
yangjianbo
41d0383fb7 merge(test): 合并 main 并解决前端筛选器冲突 2026-02-15 22:04:06 +08:00
程序猿MT
1cf51b14f7 Merge branch 'Wei-Shaw:main' into main 2026-02-15 20:49:14 +08:00
yangjianbo
372e04f69a fix(docker): 默认从cmd/server/VERSION读取版本号 2026-02-14 23:28:33 +08:00
yangjianbo
e2107ce45e fix(build): Docker 构建注入版本号并同步 aicodex 镜像脚本 2026-02-14 21:16:21 +08:00
Kortan
ab14df043a fix: accept openai x-stainless-* header to fix CORS error 2026-02-14 16:52:07 +08:00
yangjianbo
5feff6b1e5 feat: 0.1.74.9 2026-02-14 13:23:26 +08:00
yangjianbo
06b0f62e79 feat(accounts): 自动刷新改为ETag增量同步并优化单账号更新体验
- 前端自动刷新改为 ETag/304 增量合并,减少全量重刷

- 单账号更新后增加静默窗口,避免刚更新即被自动刷新覆盖

- 列表筛选移除时改为待同步提示,不再立即触发全量补页

- 后端账号列表支持 If-None-Match,命中返回 304

- 单账号接口统一补充运行时容量字段并暴露 ETag 头
2026-02-14 13:22:51 +08:00
yangjianbo
40d110efe4 chore(logging): 删除过时的日志审计文档 2026-02-14 12:36:42 +08:00
yangjianbo
f23318fbcf fix(frontend): 同步账号本地移除后的分页状态 2026-02-14 12:35:35 +08:00
yangjianbo
cbab49d65f feat: 0.1.74.8 2026-02-14 12:10:20 +08:00
yangjianbo
b5a3b3db66 Merge branch 'test' into release 2026-02-14 12:07:19 +08:00
yangjianbo
9cafa46dd3 fix(accounts): 账号管理改为单行增量更新并避免全量刷新
- 将编辑与重新授权成功事件改为回传更新后的账号对象
- 在账号列表页按 id 就地补丁更新单行数据并保留运行时容量字段
- 单账号操作(刷新凭证/清错/清限流/临时不可调度重置)改为单行更新
- 后端增强 clear-rate-limit 接口,返回更新后的账号对象
- 同步前端 clearRateLimit API 类型定义

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 12:06:17 +08:00
yangjianbo
f6bff97d26 fix(frontend): 修复前端审计问题并补充回归测试 2026-02-14 11:56:08 +08:00
yangjianbo
d04b47b3ca feat(backend): 提交后端审计修复与配套测试改动 2026-02-14 11:23:10 +08:00
yangjianbo
862199143e fix(ops): 修复错误日志查询 q 过滤列名二义性 2026-02-14 11:21:30 +08:00
yangjianbo
57e8abcb63 fix(openai): 自动透传预检 instructions 并本地 403 拦截 2026-02-14 10:49:01 +08:00
yangjianbo
ed31c54961 fix(openai): 拒绝日志记录原始 User-Agent 便于攻击研判 2026-02-14 09:59:19 +08:00
yangjianbo
4bfa69bffa fix(openai): 仅记录 codex_cli_only 拒绝日志并输出详细 User-Agent 2026-02-14 09:53:17 +08:00
yangjianbo
888f2936ad feat: version 0.1.74.7 2026-02-13 19:28:12 +08:00
yangjianbo
4e894bac1f Merge branch 'test' into release 2026-02-13 19:27:35 +08:00
yangjianbo
f96acf6e27 fix(ops): 修复日志级别过滤并增强OpenAI错误诊断日志
- 移除 warn 级别下 access info 的强制入库补写,确保运行时日志级别真实生效

- 将 OpenAI fallback matched 与 passthrough 断流提示按需求降级为 info

- 为 codex_cli_only 与 instructions required 场景补充请求诊断字段(含 User-Agent)

- 出于安全考虑移除请求体预览,仅保留 request_body_size 与白名单头信息

- 新增/更新回归测试,覆盖 Forward 入口到日志落库链路
2026-02-13 19:27:07 +08:00
yangjianbo
2459eafb71 feat: 完善日志 2026-02-13 13:35:47 +08:00
yangjianbo
ed681d0830 feat: 整理 2026-02-13 12:49:08 +08:00
yangjianbo
3734abed4c feat(openai): 支持 gpt-5.3-codex-spark 并统一 gpt-5.3 到 codex 计费 2026-02-13 09:28:07 +08:00
yangjianbo
abf5de69fb Merge branch 'main' into test 2026-02-12 23:43:47 +08:00
yangjianbo
7582dc53d2 fix(openai): 修复关闭 codex_cli_only 无法持久化问题
在编辑 OpenAI OAuth 账号时,若 codex_cli_only 从开启切换为关闭,
现改为显式写入 false,避免 extra 为空时后端忽略更新导致旧值残留。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 23:15:41 +08:00
程序猿MT
174d7c774d Merge branch 'Wei-Shaw:main' into main 2026-02-12 23:12:41 +08:00
yangjianbo
a9518cc5be feat(openai): 增加 OAuth 账号 Codex 官方客户端限制开关
新增 codex_cli_only 开关并默认关闭,关闭时完全绕过限制逻辑。
在 OpenAI 网关引入统一检测入口,集中判定账号类型、开关与客户端族。
开启后仅放行 codex_cli_rs、codex_vscode、codex_app 客户端家族。
补充后端判定与网关分支测试,并在前端创建/编辑页增加开关配置与回显。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 22:32:59 +08:00
yangjianbo
2f190d812a fix(openai): 透传OAuth强制store/stream并修复Codex识别 2026-02-12 21:02:52 +08:00
yangjianbo
d411cf4472 fix(openai): 收敛自动透传请求头并增强 OAuth 安全兜底 2026-02-12 20:12:15 +08:00
yangjianbo
1ae49b9ead feat: version 0.1.74.5 2026-02-12 19:32:13 +08:00
yangjianbo
0bf162f64a Merge branch 'dev' into release 2026-02-12 19:23:54 +08:00
yangjianbo
6423636177 Merge branch 'test' into dev 2026-02-12 19:23:35 +08:00
yangjianbo
b6aaee01ce fix(logging): 修复 warn 级别下系统日志空白问题
- 新增 logger.WriteSinkEvent,支持旁路写入 sink,不受全局级别门控影响\n- 在 http.access 中间件中,当 info 被门控时补写 sink,保障 Ops 系统日志可索引\n- 增加 level=warn 场景回归测试,验证访问日志仍可入库
2026-02-12 19:19:11 +08:00
yangjianbo
3511376c2c chore(logging): 默认使用 console 普通日志输出
- 将配置默认 log.format 从 json 调整为 console\n- 将 logger 初始化兜底默认格式调整为 console\n- 同步更新 deploy 配置示例
2026-02-12 19:07:16 +08:00
yangjianbo
584cfc3db2 chore(logging): 完成后端日志审计与结构化迁移
- 将高密度服务与处理器日志迁移到新日志系统(LegacyPrintf/结构化日志)
- 增加 stdlog bridge 与兼容测试,保留旧日志捕获能力
- 将 OpenAI 断流告警改为结构化 Warn 并改造对应测试为 sink 捕获
- 补齐后端相关文件 logger 引用并通过全量 go test
2026-02-12 19:01:09 +08:00
yangjianbo
eaa7d899f0 fix(ops): 优化系统日志展示为可读文本
解析 extra 字段(status_code/latency_ms/method/path 等)并拼成普通文本\n表格改为 3 列并固定时间/级别宽度,详情列填满后自动换行

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 18:00:16 +08:00
yangjianbo
84cc651b46 fix(logger): 修复 caller 字段与 OpsSystemLogSink 停止刷盘
修复点:

- zap logger 不再强制 AddCallerSkip(1),确保 caller 指向真实调用点

- slog handler 避免重复写 time 字段

- OpsSystemLogSink 优先从字段 component 识别业务组件;停止时 drain 队列并用可用 ctx 刷盘

补充:新增/完善对应单测
2026-02-12 17:42:29 +08:00
yangjianbo
b7243660c4 fix(deploy): 修复 Postgres 数据未持久化导致重启后无法登录
原因:postgres:18-alpine 默认 PGDATA 不在 /var/lib/postgresql/data,数据落到匿名卷,docker compose down/up 会触发 initdb 重新初始化。

修复:在 compose 中显式设置 PGDATA=/var/lib/postgresql/data,让数据落到 postgres_data 命名卷。
2026-02-12 17:42:18 +08:00
yangjianbo
e722992439 fix(setup): 数据库有用户时跳过管理员引导 2026-02-12 16:50:42 +08:00
yangjianbo
fff1d54858 feat(log): 落地统一日志底座与系统日志运维能力 2026-02-12 16:27:29 +08:00
yangjianbo
a5f29019d9 test(ops): 提升日志链路覆盖率并修复lint阻塞 2026-02-12 16:25:44 +08:00
yangjianbo
208c5380f4 fix(ops): 排除刷新信号避免分页重置页码 2026-02-12 15:00:22 +08:00
yangjianbo
29191af877 Merge branch 'dev' into release 2026-02-12 14:40:37 +08:00
yangjianbo
2d6066f985 Merge branch 'test' into dev 2026-02-12 14:40:22 +08:00
yangjianbo
3ea5e5c33a feat: update build aicodex.sh 2026-02-12 14:40:05 +08:00
yangjianbo
dbd7969a3e Merge branch 'test' into release 2026-02-12 14:27:58 +08:00
yangjianbo
af3069073a chore(lint): 修复 golangci-lint unused
- 移除 OpenAIGatewayHandler 未使用字段

- 删除并发缓存中未使用的 Redis 脚本常量

- 将仅供 unit 测试使用的 parseIntegralNumber 移入 unit build tag 文件
2026-02-12 14:20:56 +08:00
yangjianbo
65661f24e2 feat(ops): 运维监控新增 OpenAI Token 请求统计表
- 新增管理端接口 /api/v1/admin/ops/dashboard/openai-token-stats,按模型聚合统计 gpt% 请求

- 支持 time_range=30m|1h|1d|15d|30d(默认 30d),支持 platform/group_id 过滤

- 支持分页(page/page_size)或 TopN(top_n)互斥查询

- 前端运维监控页新增统计表卡片,包含空态/错误态与分页/TopN 交互

- 补齐后端与前端测试
2026-02-12 14:20:14 +08:00
yangjianbo
ed2eba9028 fix(gateway): 默认过滤OpenAI透传超时头并补充断流告警 2026-02-12 14:16:18 +08:00
yangjianbo
10c1590b1d Merge branch 'dev' into release 2026-02-12 12:12:40 +08:00
yangjianbo
114e172603 test(repository): 补充 JWT 密钥引导并发与兼容性单测 2026-02-12 12:07:20 +08:00
yangjianbo
09c8380b3d fix(repository): 修复 JWT 密钥引导冲突一致性与并发读取竞态 2026-02-12 12:04:13 +08:00
yangjianbo
ba567babf4 :erge branch 'dev' into release 2026-02-12 11:45:11 +08:00
yangjianbo
9403aa9bd1 feat: version 0.1.74.4 2026-02-12 11:44:45 +08:00
yangjianbo
34b8bbcbe4 Merge branch 'dev' into release 2026-02-12 11:43:47 +08:00
yangjianbo
6b36992d34 feat(security): 启动时自动迁移并持久化JWT密钥
- 新增 security_secrets 表及 Ent schema 用于存储系统级密钥
- 启动阶段支持无 jwt.secret 配置并在数据库中自动生成持久化
- 在 Ent 初始化后补齐密钥并执行完整配置校验
- 增加并发与异常分支单元测试,覆盖密钥引导核心路径

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 11:41:20 +08:00
yangjianbo
6533a4647d fix(openai): 增强自动透传命中日志 2026-02-12 11:41:06 +08:00
yangjianbo
9c910c2049 feat(openai): 支持自动透传开关并透传 User-Agent
- OpenAI OAuth/API Key 统一支持自动透传开关,编辑页可开关\n- 透传模式仅替换认证并保留计费/并发/审计,修复 API Key responses 端点拼接\n- Usage 页面显示原始 User-Agent 且不截断,补充回归测试与清单
2026-02-12 10:56:07 +08:00
yangjianbo
43dc23a47d Merge branch 'test' into release 2026-02-12 09:49:05 +08:00
yangjianbo
61a2bf469a feat(openai): 极致优化 OAuth 链路并补齐性能守护
- 优化 /v1/responses 热路径,减少重复解析与不必要拷贝\n- 优化并发与 token 竞争路径并补齐运行指标\n- 补充 OpenAI/Ops 相关单元测试与回归用例\n- 新增灰度阈值守护与压测脚本,支撑发布验收
2026-02-12 09:41:37 +08:00
yangjianbo
a88bb8684f fix(openai): 修复 OAuth 透传流式断开与压缩头问题
- 透传流式在客户端断开后继续 drain 上游并解析 usage,避免计费信息丢失

- 阻断透传 accept-encoding,避免压缩响应影响 SSE/usage 解析

- 阻断 proxy-authorization,避免透传代理鉴权信息

- 补充回归测试:请求头阻断与断流后 usage 采集
2026-02-11 22:17:38 +08:00
程序猿MT
8da5fac69e Merge branch 'Wei-Shaw:main' into main 2026-02-11 18:39:52 +08:00
yangjianbo
e2cdb6c758 feat: 优化build image 2026-02-11 18:07:50 +08:00
yangjianbo
f1e884ce2b feat(openai): 增加 OAuth 透传开关
- 仅对 Codex CLI 且账号开启时走原样透传(只替换认证)

- 透传模式禁用工具修正/模型替换,并旁路解析 usage 用于计费

- 管理后台增加开关与文案,ops upstream error 记录 passthrough 标记

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 00:59:39 +08:00
yangjianbo
86f3124720 perf(service): 优化重试场景 thinking 过滤性能
- 避免全量 Unmarshal 请求体,改为仅解析 messages 子树

- 顶层 thinking 使用 sjson 直接删除,减少整体重写

- content 仅在需要修改时延迟分配 new slice

- 增加 FilterThinkingBlocksForRetry 基准测试

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 00:47:26 +08:00
yangjianbo
4b309fa8b5 fix(gateway): 优化 ParseGatewayRequest 函数,使用 unsafe 提高性能并增加 JSON 校验 2026-02-10 22:12:24 +08:00
yangjianbo
166080b29c chore: 更新版本号至 0.1.74.3 2026-02-10 18:02:02 +08:00
yangjianbo
3b0910f664 Merge branch 'main' into test-sora 2026-02-10 18:01:17 +08:00
yangjianbo
e489996713 test(backend): 补充改动代码单元测试覆盖率至 85%+
新增 48 个测试用例覆盖修复代码的各分支路径:
- subscription_maintenance_queue: nil receiver/task、Stop 幂等、零值参数 (+6)
- billing_service: CalculateCostWithConfig、错误传播、SoraImageCost 等 (+12)
- timing_wheel_service: Schedule/ScheduleRecurring after Stop (+3)
- sora_media_cleanup_service: nil guard、Start/Stop 各分支、timezone (+10)
- sora_gateway_service: normalizeSoraMediaURLs、buildSoraContent 等辅助函数 (+17)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 17:52:10 +08:00
yangjianbo
54fe363257 fix(backend): 修复代码审核发现的 8 个确认问题
- P0-1: subscription_maintenance_queue 使用 RWMutex 防止 channel close/send 竞态
- P0-2: billing_service CalculateCostWithLongContext 修复被吞没的 out-range 错误
- P1-1: timing_wheel_service Schedule/ScheduleRecurring 添加 SetTimer 错误日志
- P1-2: sora_gateway_service StoreFromURLs 失败时降级使用原始 URL
- P1-3: concurrency_cache 用 Pipeline 替代 Lua 脚本兼容 Redis Cluster
- P1-6: sora_media_cleanup_service runCleanup 添加 nil cfg/storage 防护

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 17:51:49 +08:00
程序猿MT
1dd3158c7e Merge branch 'Wei-Shaw:main' into main 2026-02-10 13:55:51 +08:00
yangjianbo
5d1c51a37f fix(handler): 修复 gjson 迁移后的请求校验语义回退
- OpenAI handler: 添加 gjson.ValidBytes 校验 JSON 合法性;model 校验改为
  检查 gjson.String 类型而非仅判断非空(拒绝 model:123 等非法类型);stream
  字段添加 True/False 类型检查;sjson.SetBytes 返回值显式处理错误
- Sora handler: 添加 gjson.ValidBytes 校验;model 校验同上改为类型检查;
  messages 校验从 Exists+Type==JSON 改为 IsArray+len>0(拒绝空数组和对象)
- 补充 TestOpenAIHandler_GjsonValidation 和更新 TestSoraHandler_ValidationExtraction
  覆盖新增的边界校验场景

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 09:13:20 +08:00
yangjianbo
58912d4ac5 perf(backend): 使用 gjson/sjson 优化热路径 JSON 处理
将 API 网关热路径中的 json.Unmarshal+json.Marshal 替换为 gjson 零拷贝查询和 sjson 精准写入:
- unwrapV1InternalResponse 性能提升 22x(4009ns→182ns),内存分配减少 28.5x
- unwrapGeminiResponse、extractGeminiUsage、estimateGeminiCountTokens、ParseGeminiRateLimitResetTime 改为接收 []byte 使用 gjson 提取
- ParseGatewayRequest 的 model/stream/metadata/thinking/max_tokens 改用 gjson 类型安全提取
- Handler 层(sora/openai)改用 gjson 提取字段、sjson 注入/修改字段,移除 map[string]any 中间变量
- Sora Client 响应解析改用 gjson ForEach 遍历,减少内存分配
- 新增约 100 个单元测试用例,所有改动函数覆盖率 >85%

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 08:59:30 +08:00
yangjianbo
29ca1290b3 chore(test): 清理测试用例与类型导入 2026-02-10 00:37:56 +08:00
yangjianbo
3fcb0cc37c feat(subscription): 有界队列执行维护并改进鉴权解析 2026-02-10 00:37:47 +08:00
yangjianbo
2bfb16291f fix(unit): 修复 unit tag 测试编译与账号选择用例 2026-02-09 21:35:41 +08:00
yangjianbo
d367d1cde6 Merge branch 'main' into test-sora 2026-02-09 20:40:09 +08:00
yangjianbo
3c46f7d266 fix: update .gitignore to include frontend coverage directory 2026-02-09 20:26:46 +08:00
yangjianbo
16131c3d3f Merge branch 'main' of https://github.com/mt21625457/aicodex2api 2026-02-09 20:26:03 +08:00
yangjianbo
d7011163b8 fix: 修复代码审核发现的安全和质量问题
安全修复(P0):
- 移除硬编码的 OAuth client_secret(Antigravity、Gemini CLI),
  改为通过环境变量注入(ANTIGRAVITY_OAUTH_CLIENT_SECRET、
  GEMINI_CLI_OAUTH_CLIENT_SECRET)
- 新增 logredact.RedactText() 对非结构化文本做敏感信息脱敏,
  覆盖 GOCSPX-*/AIza* 令牌和常见 key=value 模式
- 日志中不再打印 org_uuid、account_uuid、email_address 等敏感值

安全修复(P1):
- URL 验证增强:新增 ValidateHTTPURL 统一入口,支持 allowlist 和
  私网地址阻断(localhost/内网 IP)
- 代理回退安全:代理初始化失败时默认阻止直连回退,防止 IP 泄露,
  可通过 security.proxy_fallback.allow_direct_on_error 显式开启
- Gemini OAuth 配置校验:client_id 与 client_secret 必须同时
  设置或同时留空

其他改进:
- 新增 tools/secret_scan.py 密钥扫描工具和 Makefile secret-scan 目标
- 更新所有 docker-compose 和部署配置,传递 OAuth secret 环境变量
- google_one OAuth 类型使用固定 redirectURI,与 code_assist 对齐

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 09:58:13 +08:00
yangjianbo
fc8a39e0f5 test: 删除CI工作流,大幅提升后端单元测试覆盖率至50%+
删除因GitHub计费锁定而失败的CI工作流。
为6个核心Go源文件补充单元测试,全部达到50%以上覆盖率:
- response/response.go: 97.6%
- antigravity/oauth.go: 90.1%
- antigravity/client.go: 88.6% (新增27个HTTP客户端测试)
- geminicli/oauth.go: 91.8%
- service/oauth_service.go: 61.2%
- service/gemini_oauth_service.go: 51.9%

新增/增强8个测试文件,共计5600+行测试代码。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 09:07:58 +08:00
yangjianbo
9da80e9fda feat: update 2026-02-08 12:13:29 +08:00
yangjianbo
bb5a5dd65e test: 完善自动化测试体系(7个模块,73个任务)
系统性地修复、补充和强化项目的自动化测试能力:

1. 测试基础设施修复
   - 修复 stubConcurrencyCache 缺失方法和构造函数参数不匹配
   - 创建 testutil 共享包(stubs.go, fixtures.go, httptest.go)
   - 为所有 Stub 添加编译期接口断言

2. 中间件测试补充
   - 新增 JWT 认证中间件测试(有效/过期/篡改/缺失 Token)
   - 补充 rate_limiter 和 recovery 中间件测试场景

3. 网关核心路径测试
   - 新增账户选择、等待队列、流式响应、并发控制、计费、Claude Code 检测测试
   - 覆盖负载均衡、粘性会话、SSE 转发、槽位管理等关键逻辑

4. 前端测试体系(11个新测试文件,163个测试用例)
   - Pinia stores: auth, app, subscriptions
   - API client: 请求拦截器、响应拦截器、401 刷新
   - Router guards: 认证重定向、管理员权限、简易模式限制
   - Composables: useForm, useTableLoader, useClipboard
   - Components: LoginForm, ApiKeyCreate, Dashboard

5. CI/CD 流水线重构
   - 重构 backend-ci.yml 为统一的 ci.yml
   - 前后端 4 个并行 Job + Postgres/Redis services
   - Race 检测、覆盖率收集与门禁、Docker 构建验证

6. E2E 自动化测试
   - e2e-test.sh 自动化脚本(Docker 启动→健康检查→测试→清理)
   - 用户注册→登录→API Key→网关调用完整链路测试
   - Mock 模式和 API Key 脱敏支持

7. 修复预存问题
   - tlsfingerprint dialer_test.go 缺失 build tag 导致集成测试编译冲突

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 12:05:39 +08:00
yangjianbo
53e1c8b268 perf(日志): 降噪优化,将常规成功日志降级为 Debug 级别
- GIN Logger 中间件跳过 /health 和 /setup/status 的请求日志
- UsageCleanup 空闲轮询(no_task)日志降级为 slog.Debug
- Scheduler 常规 rebuild ok 日志降级为 slog.Debug
- DashboardAggregation 常规聚合完成日志降级为 slog.Debug
- TokenRefresh 无刷新活动时周期日志降级为 slog.Debug

生产环境(Info 级别)下自动静默,debug 模式下仍可见。
错误、警告类日志保持原有级别不变。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 23:29:24 +08:00
yangjianbo
d876686a00 feat: update skills 2026-02-07 22:25:57 +08:00
yangjianbo
7546a56736 feat: update skills 2026-02-07 22:21:39 +08:00
yangjianbo
00caf0bcd8 test: 为代码审核修复添加详细单元测试(7个测试文件,50+测试用例)
新增测试文件:
- cors_test.go: CORS 条件化头部测试(12个测试,覆盖白名单/黑名单/通配符/凭证/多源/Vary)
- gateway_helper_backoff_test.go: nextBackoff 退避测试(6个测试+基准,验证指数增长/边界/抖动/收敛)
- billing_cache_jitter_test.go: jitteredTTL 抖动测试(5个测试+基准,验证范围/上界/方差/均值)
- subscription_calculate_progress_test.go: calculateProgress 纯函数测试(9个测试,覆盖日/周/月限额/超限截断/过期)
- openai_gateway_handler_test.go: SSE JSON 转义测试(7个子用例,验证双引号/反斜杠/换行符安全)

更新测试文件:
- response_transformer_test.go: 增强 generateRandomID 测试(7个测试,含并发/字符集/降级计数器)
- security_headers_test.go: 适配 GenerateNonce 新签名
- api_key_auth_test.go: 适配 NewSubscriptionService 新参数

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 22:14:07 +08:00
yangjianbo
9634494ba9 fix: 修复代码审核发现的10个问题(P0安全+P1数据一致性+P2性能优化)
P0: OpenAI SSE 错误消息 JSON 注入 — 使用 json.Marshal 替代 fmt.Sprintf
P1: subscription 续期包裹 Ent 事务确保原子性
P1: CSP nonce 生成处理 crypto/rand 错误,失败降级为 unsafe-inline
P1: singleflight 透传数据库真实错误,不再吞没为 not found
P1: GetUserSubscriptionsWithProgress 提取 calculateProgress 消除 N+1
P2: billing_cache/gateway_helper 迁移到 math/rand/v2 消除全局锁争用
P2: generateRandomID 降级分支增加原子计数器防碰撞
P2: CORS 非白名单 origin 不再设置 Allow-Headers/Methods/Max-Age
P2: Turnstile 验证移除 VerifyCode 空值跳过条件防绕过
P2: Redis Cluster Lua 脚本空 KEYS 添加兼容性警告注释

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 22:13:45 +08:00
yangjianbo
e1ac0db05c feat: 优化skills 2026-02-07 21:57:29 +08:00
yangjianbo
6f3e77a2df feat: update skills 2026-02-07 21:32:42 +08:00
yangjianbo
4a20a2a8ba fix: 修复批量更新凭证明细与缓存TTL抖动
- BatchUpdateCredentials 返回 success/failed/results 及 success_ids/failed_ids

- billing jitteredTTL 改为只减不增,确保TTL不超上界

- crypto/rand 失败时随机ID降级避免 panic

- OpenAI SelectAccount 失败日志去重并补充字段

- 修复两处类型断言以通过 errcheck
2026-02-07 21:18:03 +08:00
yangjianbo
bc3ca5f068 chore: 更新版本号至 0.1.74.1 2026-02-07 20:39:06 +08:00
yangjianbo
fd43be8d0b merge: 合并 main 分支到 test,解决 config 和 modelWhitelist 冲突
- config.go: 保留 Sora 配置,合入 SubscriptionCache 配置
- useModelWhitelist.ts: 同时保留 soraModels 和 antigravityModels

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 20:18:07 +08:00
yangjianbo
836ba14b70 fix: 修复函数签名变更后的调用参数不匹配
- handleUpstreamError 补齐新增的三个参数 (0, "", false)
- handleStreamingResponse 移除已删除的 nil 参数

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 20:05:29 +08:00
yangjianbo
a14dfb769a Merge branch 'dev-release' 2026-02-07 19:58:00 +08:00
yangjianbo
2588fa6a8f fix(audit): 第二批审计修复 — P0 生产 Bug、安全加固、性能优化、缓存一致性、代码质量
基于 backend-code-audit 审计报告,修复剩余 P0/P1/P2 共 34 项问题:

P0 生产 Bug:
- 修复 time.Since(time.Now()) 计时逻辑错误 (P0-03)
- generateRandomID 改用 crypto/rand 替代固定索引 (P0-04)
- IncrementQuotaUsed 重写为 Ent 原子操作消除 TOCTOU 竞态 (P0-05)

安全加固:
- gateway/openai handler 错误响应替换为泛化消息,防止内部信息泄露 (P1-14)
- usage_log_repo dateFormat 参数改用白名单映射,防止 SQL 注入 (P1-16)
- 默认配置安全加固:sslmode=prefer、response_headers=true、mode=release (P1-18/19, P2-15)

性能优化:
- gateway handler 循环内 defer 替换为显式 releaseWait 闭包 (P1-02)
- group_repo/promo_code_repo Count 前 Clone 查询避免状态污染 (P1-03)
- usage_log_repo 四个查询添加 LIMIT 10000 防止 OOM (P1-07)
- GetBatchUsageStats 添加时间范围参数,默认最近 30 天 (P1-10)
- ip.go CIDR 预编译为包级变量 (P1-11)
- BatchUpdateCredentials 重构为先验证后更新 (P1-13)

缓存一致性:
- billing_cache 添加 jitteredTTL 防止缓存雪崩 (P2-10)
- DeductUserBalance/UpdateSubscriptionUsage 错误传播修复 (P2-12)
- UserService.UpdateBalance 成功后异步失效 billingCache (P2-13)

代码质量:
- search 截断改为按 rune 处理,支持多字节字符 (P2-01)
- TLS Handshake 改为 HandshakeContext 支持 context 取消 (P2-07)
- CORS 预检添加 Access-Control-Max-Age: 86400 (P2-16)

测试覆盖:
- 新增 user_service_test.go(UpdateBalance 缓存失效 6 个用例)
- 新增 batch_update_credentials_test.go(fail-fast + 类型验证 7 个用例)
- 新增 response_transformer_test.go、ip_test.go、usage_log_repo_unit_test.go、search_truncate_test.go
- 集成测试:IncrementQuotaUsed 并发测试、billing_cache 错误传播测试
- config_test.go 补充 server.mode/sslmode 默认值断言

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 19:46:42 +08:00
yangjianbo
f6ca701917 fix(oauth): SessionStore.Stop() 添加 sync.Once 防重入保护 (P1-05)
oauth 和 openai 包的 SessionStore.Stop() 直接调用 close(stopCh),
重复调用会导致 panic。使用 sync.Once 包裹确保幂等安全。

新增单元测试覆盖连续调用和 50 goroutine 并发调用场景。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 17:39:18 +08:00
yangjianbo
a84604dceb fix(config): 禁止 server.frontend_url 携带 query/userinfo 2026-02-07 17:37:08 +08:00
yangjianbo
e75d3e3584 fix(security): 修复密码重置链接 Host Header 注入漏洞 (P0-07)
ForgotPassword 原来从 c.Request.Host 构建重置链接基础 URL,攻击者
可伪造 Host 头将重置链接指向恶意域名窃取 token。

修复方案:
- ServerConfig 新增 frontend_url 配置项
- auth_handler 改为从配置读取前端 URL,未配置时拒绝请求
- Validate() 校验 frontend_url 必须为绝对 HTTP(S) URL
- 新增 TestValidateServerFrontendURL 单元测试
- config.example.yaml 添加配置说明

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 17:15:26 +08:00
yangjianbo
8226a4ce4d perf(service): 优化 model 替换函数,用 gjson/sjson 替代全量 JSON 序列化
SSE 热路径中 replaceModelInSSELine 和 replaceModelInResponseBody 原来
使用 json.Unmarshal/Marshal 对每个事件做全量反序列化再序列化,现改为
gjson.Get/sjson.Set 精确字段操作,消除 O(n) 中间 map 分配,保持 JSON
字段顺序不变。涉及 OpenAIGatewayService 和 GatewayService 两个服务。

新增 23 个单元测试覆盖:顶层/嵌套 model 替换、不匹配跳过、空行/[DONE]/
非法 JSON 等边界情况。

Fixes: P1-08

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 17:09:55 +08:00
yangjianbo
65c0d8b51f fix(middleware): 管理员JWT增加TokenVersion校验
管理员改密后旧JWT会被拒绝,并补充单元测试覆盖。
2026-02-07 16:34:57 +08:00
yangjianbo
a9e256ce8c fix(openai): 修复 usage 为空导致 panic(P0-02) 2026-02-07 16:15:30 +08:00
yangjianbo
7e1674e43a chore(version): 更新版本号至 0.1.70.2 2026-02-07 14:58:52 +08:00
yangjianbo
fc104dfb56 feat:增加端口 2026-02-07 14:57:50 +08:00
yangjianbo
0e514ed80b perf(middleware): 优化订阅模式认证中间件,5次串行调用降至2步同步+1步异步
- 为 GetActiveSubscription 添加 ristretto L1 缓存 + singleflight 防击穿
- 合并 ValidateSubscription + CheckUsageLimits 为纯内存 ValidateAndCheckLimits
- 窗口维护操作(激活/重置)异步化,不再阻塞首字节
- 缓存返回浅拷贝,避免并发 data race 和缓存污染
- 所有管理操作(分配/续期/撤销/扩展/窗口重置)同步失效 L1 缓存
- 新增 SubscriptionCacheConfig 可配置 L1 缓存大小/TTL/抖动

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 14:43:12 +08:00
yangjianbo
782a54a8a1 chore(version): 更新版本号至 0.1.70.1 2026-02-07 11:17:46 +08:00
yangjianbo
4e01126ff2 test(codex): 清理无用的 opencode 缓存测试
移除不再需要的 setupCodexCache 调用与辅助函数(已不再回源/读写缓存)
2026-02-07 09:53:01 +08:00
yangjianbo
55b56328da feat(codex): 移除 opencode 指令回源与缓存
- 不再从 GitHub 拉取 opencode codex_header.txt\n- 删除 ~/.opencode 缓存与异步刷新逻辑\n- 所有 instructions 统一使用内置 codex_cli_instructions.md
2026-02-07 09:28:32 +08:00
yangjianbo
ce764bf2d9 feat(gateway): 支持强制 Codex CLI 模式并伪装 UA
- Codex CLI 请求仅使用内置 instructions,不再读取 opencode 缓存/回源\n- 新增 gateway.force_codex_cli(环境变量 GATEWAY_FORCE_CODEX_CLI)\n- ForceCodexCLI=true 时转发上游强制 User-Agent=codex_cli_rs/0.0.0\n- 更新 deploy 示例配置
2026-02-07 09:21:15 +08:00
yangjianbo
d71537d431 perf(service): SSE Scanner buffer 改用 sync.Pool 复用,减少高并发 GC 压力
将流式响应中 bufio.Scanner 的 64KB buffer 从每次 make 分配改为
sync.Pool 复用,统一切片表达式为 [:0]、变量命名为 scanBuf,
并补充对应的单元测试。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 22:55:12 +08:00
yangjianbo
ae1ba45350 perf(service): jitterTTL 改用 rand/v2 并移除锁 2026-02-06 21:22:38 +08:00
yangjianbo
c4182f8c33 perf(service): 移除 jitter 随机数全局锁 2026-02-06 21:20:25 +08:00
yangjianbo
028f8aaa97 feat: 优化.env参数 2026-02-06 21:01:30 +08:00
yangjianbo
d3f11fdbd3 chore(deploy): aicodex 默认 max_conns_per_host=8192 2026-02-06 20:50:44 +08:00
yangjianbo
8672b2f3ec chore(gateway): 提升 max_idle_conns 并补齐 env 2026-02-06 20:48:48 +08:00
yangjianbo
de753a149e chore(deploy): 补齐连接池默认与 8G 参数 2026-02-06 20:44:08 +08:00
yangjianbo
2d4bbbf49d feat: 优化codex冷启动, 还有连接池数据库配置信息 2026-02-06 20:31:42 +08:00
yangjianbo
792bef615c Merge branch 'main' into test 2026-02-06 09:59:15 +08:00
yangjianbo
000a943cce Merge branch 'main' into test 2026-02-06 08:43:42 +08:00
yangjianbo
f82e346f02 Merge branch 'main' into test 2026-02-06 08:12:40 +08:00
yangjianbo
d8e405511e Merge branch 'main' of https://github.com/mt21625457/aicodex2api 2026-02-06 06:56:23 +08:00
yangjianbo
74d35f0860 chore(ent): 重新生成代码 2026-02-04 20:41:26 +08:00
yangjianbo
de7ff902de Merge branch 'main' into test 2026-02-04 20:35:09 +08:00
yangjianbo
317f26f0bf feat: update caddy 2026-02-04 19:27:51 +08:00
程序猿MT
dd96ada3c6 Merge branch 'Wei-Shaw:main' into main 2026-02-04 18:56:47 +08:00
yangjianbo
9b120e68b8 fix(sora): 恢复流式辅助逻辑并通过 lint 2026-02-04 14:06:06 +08:00
yangjianbo
377bffe281 Merge branch 'main' into test 2026-02-03 22:48:04 +08:00
yangjianbo
31fe017888 Merge branch 'main' of https://github.com/mt21625457/aicodex2api 2026-02-03 21:00:11 +08:00
yangjianbo
99250ec527 fix(Sora): 加固直连安全与下载限制
补充图片输入 SSRF 防护与重定向限制\n增加媒体下载超时/大小上限配置并更新示例\n完善 recent_tasks 轮询回退策略与相关测试\n\n测试: go test ./... -tags=unit
2026-02-01 22:10:15 +08:00
yangjianbo
dcf5f60237 feat: add codex skills 2026-02-01 21:38:00 +08:00
yangjianbo
399dd78b2a feat(Sora): 直连生成并移除sora2api依赖
实现直连 Sora 客户端、媒体落地与清理策略\n更新网关与前端配置以支持 Sora 平台\n补齐单元测试与契约测试,新增 curl 测试脚本\n\n测试: go test ./... -tags=unit
2026-02-01 21:37:10 +08:00
yangjianbo
78d0ca3775 fix(sora): 修复流式重写与计费问题 2026-01-31 21:46:28 +08:00
yangjianbo
618a614cbf feat(Sora): 完成Sora网关接入与媒体能力
新增 Sora 网关路由、账号调度与同步服务\n补充媒体代理与签名 URL、模型列表动态拉取\n完善计费配置、前端支持与相关测试
2026-01-31 20:22:22 +08:00
yangjianbo
99dc3b59bc feat(账号): 添加 Sora 账号双表同步与创建
- 新增 sora_accounts 表与 accounts.extra GIN 索引\n- OpenAI OAuth 支持同时创建 Sora 账号并同步配置\n- Token 刷新同步关联 Sora 账号凭证与扩展表\n- 增加 Sora 账号连通性测试与前端开关文案
2026-01-30 14:08:04 +08:00
yangjianbo
d9e345f23d Merge branch 'test' of https://github.com/mt21625457/aicodex2api into test 2026-01-29 20:34:21 +08:00
yangjianbo
a505d992ee feat: 优化配置 2026-01-29 20:33:26 +08:00
yangjianbo
13262a5698 feat(sora): 新增 Sora 平台支持并修复高危安全和性能问题
新增功能:
- 新增 Sora 账号管理和 OAuth 认证
- 新增 Sora 视频/图片生成 API 网关
- 新增 Sora 任务调度和缓存机制
- 新增 Sora 使用统计和计费支持
- 前端增加 Sora 平台配置界面

安全修复(代码审核):
- [SEC-001] 限制媒体下载响应体大小(图片 20MB、视频 200MB),防止 DoS 攻击
- [SEC-002] 限制 SDK API 响应大小(1MB),防止内存耗尽
- [SEC-003] 修复 SSRF 风险,添加 URL 验证并强制使用代理配置

BUG 修复(代码审核):
- [BUG-001] 修复 for 循环内 defer 累积导致的资源泄漏
- [BUG-002] 修复图片并发槽位获取失败时已持有锁未释放的永久泄漏

性能优化(代码审核):
- [PERF-001] 添加 Sentinel Token 缓存(3 分钟有效期),减少 PoW 计算开销

技术细节:
- 使用 io.LimitReader 限制所有外部输入的大小
- 添加 urlvalidator 验证防止 SSRF 攻击
- 使用 sync.Map 实现线程安全的包级缓存
- 优化并发槽位管理,添加 releaseAll 模式防止泄漏

影响范围:
- 后端:新增 Sora 相关数据模型、服务、网关和管理接口
- 前端:新增 Sora 平台配置、账号管理和监控界面
- 配置:新增 Sora 相关配置项和环境变量

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-29 16:18:38 +08:00
yangjianbo
bece1b5201 perf(服务端): 启用 h2c 并保留 HTTP/1.1 回退 2026-01-24 20:01:03 +08:00
486 changed files with 71827 additions and 35962 deletions

View File

@@ -44,4 +44,4 @@ jobs:
with:
version: v2.7
args: --timeout=5m
working-directory: backend
working-directory: backend

7
.gitignore vendored
View File

@@ -121,7 +121,6 @@ AGENTS.md
scripts
.code-review-state
openspec/
docs/
code-reviews/
AGENTS.md
backend/cmd/server/server
@@ -129,4 +128,8 @@ deploy/docker-compose.override.yml
.gocache/
vite.config.js
docs/*
.serena/
.serena/
.codex/
frontend/coverage/
aicodex

View File

@@ -209,7 +209,30 @@ git add ent/ # 生成的文件也要提交
---
### 坑 10PR 提交前检查清单
### 坑 10前端测试看似正常,但后端调用失败(模型映射被批量误改)
**典型现象**
- 前端按钮点测看起来正常;
- 实际通过 API/客户端调用时返回 `Service temporarily unavailable` 或提示无可用账号;
- 常见于 OpenAI 账号(例如 Codex 模型)在批量修改后突然不可用。
**根因**
- OpenAI 账号编辑页默认不显式展示映射规则,容易让人误以为“没映射也没关系”;
- 但在**批量修改同时选中不同平台账号**OpenAI + Antigravity/Gemini模型白名单/映射可能被跨平台策略覆盖;
- 结果是 OpenAI 账号的关键模型映射丢失或被改坏,后端选不到可用账号。
**修复方案(按优先级)**
1. **快速修复(推荐)**:在批量修改中补回正确的透传映射(例如 `gpt-5.3-codex -> gpt-5.3-codex-spark`)。
2. **彻底重建**:删除并重新添加全部相关账号(最稳但成本高)。
**关键经验**
- 如果某模型已被软件内置默认映射覆盖,通常不需要额外再加透传;
- 但当上游模型更新快于本仓库默认映射时,**手动批量添加透传映射**是最简单、最低风险的临时兜底方案;
- 批量操作前尽量按平台分组,不要混选不同平台账号。
---
### 坑 11PR 提交前检查清单
提交 PR 前务必本地验证:

View File

@@ -36,7 +36,7 @@ RUN pnpm run build
FROM ${GOLANG_IMAGE} AS backend-builder
# Build arguments for version info (set by CI)
ARG VERSION=docker
ARG VERSION=
ARG COMMIT=docker
ARG DATE
ARG GOPROXY
@@ -61,9 +61,13 @@ COPY backend/ ./
COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
# Build the binary (BuildType=release for CI builds, embed frontend)
RUN CGO_ENABLED=0 GOOS=linux go build \
# Version precedence: build arg VERSION > cmd/server/VERSION
RUN VERSION_VALUE="${VERSION}" && \
if [ -z "${VERSION_VALUE}" ]; then VERSION_VALUE="$(tr -d '\r\n' < ./cmd/server/VERSION)"; fi && \
DATE_VALUE="${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)}" && \
CGO_ENABLED=0 GOOS=linux go build \
-tags embed \
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
-ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \
-o /app/sub2api \
./cmd/server

View File

@@ -1,4 +1,4 @@
.PHONY: build build-backend build-frontend test test-backend test-frontend
.PHONY: build build-backend build-frontend test test-backend test-frontend secret-scan
# 一键编译前后端
build: build-backend build-frontend
@@ -20,3 +20,6 @@ test-backend:
test-frontend:
@pnpm --dir frontend run lint:check
@pnpm --dir frontend run typecheck
secret-scan:
@python3 tools/secret_scan.py

View File

@@ -363,6 +363,12 @@ default:
rate_multiplier: 1.0
```
### Sora Status (Temporarily Unavailable)
> ⚠️ Sora-related features are temporarily unavailable due to technical issues in upstream integration and media delivery.
> Please do not rely on Sora in production at this time.
> Existing `gateway.sora_*` configuration keys are reserved and may not take effect until these issues are resolved.
Additional security-related options are available in `config.yaml`:
- `cors.allowed_origins` for CORS allowlist

View File

@@ -139,6 +139,8 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
如果你的服务器是 **Ubuntu 24.04**,建议直接参考:`deploy/ubuntu24-docker-compose-aicodex.md`,其中包含「安装最新版 Docker + docker-compose-aicodex.yml 部署」的完整步骤。
#### 前置条件
- Docker 20.10+
@@ -370,6 +372,33 @@ default:
rate_multiplier: 1.0
```
### Sora 功能状态(暂不可用)
> ⚠️ 当前 Sora 相关功能因上游接入与媒体链路存在技术问题,暂时不可用。
> 现阶段请勿在生产环境依赖 Sora 能力。
> 文档中的 `gateway.sora_*` 配置仅作预留,待技术问题修复后再恢复可用。
### Sora 媒体签名 URL功能恢复后可选
当配置 `gateway.sora_media_signing_key``gateway.sora_media_signed_url_ttl_seconds > 0` 时,网关会将 Sora 输出的媒体地址改写为临时签名 URL`/sora/media-signed/...`)。这样无需 API Key 即可在浏览器中直接访问,且具备过期控制与防篡改能力(签名包含 path + query
```yaml
gateway:
# /sora/media 是否强制要求 API Key默认 false
sora_media_require_api_key: false
# 媒体临时签名密钥(为空则禁用签名)
sora_media_signing_key: "your-signing-key"
# 临时签名 URL 有效期(秒)
sora_media_signed_url_ttl_seconds: 900
```
> 若未配置签名密钥,`/sora/media-signed` 将返回 503。
> 如需更严格的访问控制,可将 `sora_media_require_api_key` 设为 true仅允许携带 API Key 的 `/sora/media` 访问。
访问策略说明:
- `/sora/media`:内部调用或客户端携带 API Key 才能下载
- `/sora/media-signed`:外部可访问,但有签名 + 过期控制
`config.yaml` 还支持以下安全相关配置:
- `cors.allowed_origins` 配置 CORS 白名单
@@ -383,6 +412,14 @@ default:
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
- `turnstile.required` 在 release 模式强制启用 Turnstile
**网关防御纵深建议(重点)**
- `gateway.upstream_response_read_max_bytes`:限制非流式上游响应读取大小(默认 `8MB`),用于防止异常响应导致内存放大。
- `gateway.proxy_probe_response_read_max_bytes`:限制代理探测响应读取大小(默认 `1MB`)。
- `gateway.gemini_debug_response_headers`:默认 `false`,仅在排障时短时开启,避免高频请求日志开销。
- `/auth/register``/auth/login``/auth/login/2fa``/auth/send-verify-code` 已提供服务端兜底限流Redis 故障时 fail-close
- 推荐将 WAF/CDN 作为第一层防护,服务端限流与响应读取上限作为第二层兜底;两层同时保留,避免旁路流量与误配置风险。
**⚠️ 安全警告HTTP URL 配置**
`security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL例如用于开发或内网测试必须显式设置
@@ -428,6 +465,29 @@ Invalid base URL: invalid url scheme: http
./sub2api
```
#### HTTP/2 (h2c) 与 HTTP/1.1 回退
后端明文端口默认支持 h2c并保留 HTTP/1.1 回退用于 WebSocket 与旧客户端。浏览器通常不支持 h2c性能收益主要在反向代理或内网链路。
**反向代理示例Caddy**
```caddyfile
transport http {
versions h2c h1
}
```
**验证:**
```bash
# h2c prior knowledge
curl --http2-prior-knowledge -I http://localhost:8080/health
# HTTP/1.1 回退
curl --http1.1 -I http://localhost:8080/health
# WebSocket 回退验证(需管理员 token
websocat -H="Sec-WebSocket-Protocol: sub2api-admin, jwt.<ADMIN_TOKEN>" ws://localhost:8080/api/v1/admin/ops/ws/qps
```
#### 开发模式
```bash

View File

@@ -14,4 +14,7 @@ test-integration:
go test -tags=integration ./...
test-e2e:
go test -tags=e2e ./...
./scripts/e2e-test.sh
test-e2e-local:
go test -tags=e2e -v -timeout=300s ./internal/integration/...

View File

@@ -17,7 +17,7 @@ func main() {
email := flag.String("email", "", "Admin email to issue a JWT for (defaults to first active admin)")
flag.Parse()
cfg, err := config.Load()
cfg, err := config.LoadForBootstrap()
if err != nil {
log.Fatalf("failed to load config: %v", err)
}

View File

@@ -1 +1 @@
0.1.76
0.1.85

View File

@@ -8,7 +8,6 @@ import (
"errors"
"flag"
"log"
"log/slog"
"net/http"
"os"
"os/signal"
@@ -19,11 +18,14 @@ import (
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/setup"
"github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
//go:embed VERSION
@@ -38,7 +40,12 @@ var (
)
func init() {
// Read version from embedded VERSION file
// 如果 Version 已通过 ldflags 注入(例如 -X main.Version=...),则不要覆盖。
if strings.TrimSpace(Version) != "" {
return
}
// 默认从 embedded VERSION 文件读取版本号(编译期打包进二进制)。
Version = strings.TrimSpace(embeddedVersion)
if Version == "" {
Version = "0.0.0-dev"
@@ -47,22 +54,9 @@ func init() {
// initLogger configures the default slog handler based on gin.Mode().
// In non-release mode, Debug level logs are enabled.
func initLogger() {
var level slog.Level
if gin.Mode() == gin.ReleaseMode {
level = slog.LevelInfo
} else {
level = slog.LevelDebug
}
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
slog.SetDefault(slog.New(handler))
}
func main() {
// Initialize slog logger based on gin mode
initLogger()
logger.InitBootstrap()
defer logger.Sync()
// Parse command line flags
setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
@@ -122,16 +116,26 @@ func runSetupServer() {
log.Printf("Setup wizard available at http://%s", addr)
log.Println("Complete the setup wizard to configure Sub2API")
if err := r.Run(addr); err != nil {
server := &http.Server{
Addr: addr,
Handler: h2c.NewHandler(r, &http2.Server{}),
ReadHeaderTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
}
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start setup server: %v", err)
}
}
func runMainServer() {
cfg, err := config.Load()
cfg, err := config.LoadForBootstrap()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
if err := logger.Init(logger.OptionsFromConfig(cfg.Log)); err != nil {
log.Fatalf("Failed to initialize logger: %v", err)
}
if cfg.RunMode == config.RunModeSimple {
log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED")
}

View File

@@ -67,14 +67,19 @@ func provideCleanup(
opsAlertEvaluator *service.OpsAlertEvaluatorService,
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
opsSystemLogSink *service.OpsSystemLogSink,
soraMediaCleanup *service.SoraMediaCleanupService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
idempotencyCleanup *service.IdempotencyCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
subscriptionService *service.SubscriptionService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
@@ -101,6 +106,18 @@ func provideCleanup(
}
return nil
}},
{"OpsSystemLogSink", func() error {
if opsSystemLogSink != nil {
opsSystemLogSink.Stop()
}
return nil
}},
{"SoraMediaCleanupService", func() error {
if soraMediaCleanup != nil {
soraMediaCleanup.Stop()
}
return nil
}},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
@@ -131,6 +148,12 @@ func provideCleanup(
}
return nil
}},
{"IdempotencyCleanupService", func() error {
if idempotencyCleanup != nil {
idempotencyCleanup.Stop()
}
return nil
}},
{"TokenRefreshService", func() error {
tokenRefresh.Stop()
return nil
@@ -143,6 +166,12 @@ func provideCleanup(
subscriptionExpiry.Stop()
return nil
}},
{"SubscriptionService", func() error {
if subscriptionService != nil {
subscriptionService.Stop()
}
return nil
}},
{"PricingService", func() error {
pricing.Stop()
return nil
@@ -155,6 +184,12 @@ func provideCleanup(
billingCache.Stop()
return nil
}},
{"UsageRecordWorkerPool", func() error {
if usageRecordWorkerPool != nil {
usageRecordWorkerPool.Stop()
}
return nil
}},
{"OAuthService", func() error {
oauth.Stop()
return nil

View File

@@ -65,8 +65,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
@@ -98,10 +98,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
schedulerCache := repository.NewSchedulerCache(redisClient)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
soraAccountRepository := repository.NewSoraAccountRepository(db)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
@@ -112,7 +113,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
driveClient := repository.NewGeminiDriveClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
@@ -159,14 +161,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService)
idempotencyRepository := repository.NewIdempotencyRepository(client, db)
systemOperationLockService := service.ProvideSystemOperationLockService(idempotencyRepository, configConfig)
systemHandler := handler.ProvideSystemHandler(updateService, systemOperationLockService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
@@ -180,11 +185,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
soraGatewayService := service.NewSoraGatewayService(soraSDKClient, soraMediaStorage, rateLimitService, configConfig)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
totpHandler := handler.NewTotpHandler(totpService)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -195,10 +207,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -228,14 +241,19 @@ func provideCleanup(
opsAlertEvaluator *service.OpsAlertEvaluatorService,
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
opsSystemLogSink *service.OpsSystemLogSink,
soraMediaCleanup *service.SoraMediaCleanupService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
idempotencyCleanup *service.IdempotencyCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
subscriptionService *service.SubscriptionService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
@@ -261,6 +279,18 @@ func provideCleanup(
}
return nil
}},
{"OpsSystemLogSink", func() error {
if opsSystemLogSink != nil {
opsSystemLogSink.Stop()
}
return nil
}},
{"SoraMediaCleanupService", func() error {
if soraMediaCleanup != nil {
soraMediaCleanup.Stop()
}
return nil
}},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
@@ -291,6 +321,12 @@ func provideCleanup(
}
return nil
}},
{"IdempotencyCleanupService", func() error {
if idempotencyCleanup != nil {
idempotencyCleanup.Stop()
}
return nil
}},
{"TokenRefreshService", func() error {
tokenRefresh.Stop()
return nil
@@ -303,6 +339,12 @@ func provideCleanup(
subscriptionExpiry.Stop()
return nil
}},
{"SubscriptionService", func() error {
if subscriptionService != nil {
subscriptionService.Stop()
}
return nil
}},
{"PricingService", func() error {
pricing.Stop()
return nil
@@ -315,6 +357,12 @@ func provideCleanup(
billingCache.Stop()
return nil
}},
{"UsageRecordWorkerPool", func() error {
if usageRecordWorkerPool != nil {
usageRecordWorkerPool.Stop()
}
return nil
}},
{"OAuthService", func() error {
oauth.Stop()
return nil

View File

@@ -36,6 +36,8 @@ type APIKey struct {
GroupID *int64 `json:"group_id,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// Last usage time of this API key
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
// Allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]
IPWhitelist []string `json:"ip_whitelist,omitempty"`
// Blocked IPs/CIDRs
@@ -109,7 +111,7 @@ func (*APIKey) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullInt64)
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
values[i] = new(sql.NullString)
case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldExpiresAt:
case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -182,6 +184,13 @@ func (_m *APIKey) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Status = value.String
}
case apikey.FieldLastUsedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field last_used_at", values[i])
} else if value.Valid {
_m.LastUsedAt = new(time.Time)
*_m.LastUsedAt = value.Time
}
case apikey.FieldIPWhitelist:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i])
@@ -296,6 +305,11 @@ func (_m *APIKey) String() string {
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
if v := _m.LastUsedAt; v != nil {
builder.WriteString("last_used_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("ip_whitelist=")
builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist))
builder.WriteString(", ")

View File

@@ -31,6 +31,8 @@ const (
FieldGroupID = "group_id"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldLastUsedAt holds the string denoting the last_used_at field in the database.
FieldLastUsedAt = "last_used_at"
// FieldIPWhitelist holds the string denoting the ip_whitelist field in the database.
FieldIPWhitelist = "ip_whitelist"
// FieldIPBlacklist holds the string denoting the ip_blacklist field in the database.
@@ -83,6 +85,7 @@ var Columns = []string{
FieldName,
FieldGroupID,
FieldStatus,
FieldLastUsedAt,
FieldIPWhitelist,
FieldIPBlacklist,
FieldQuota,
@@ -176,6 +179,11 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
}
// ByLastUsedAt orders the results by the last_used_at field.
func ByLastUsedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldLastUsedAt, opts...).ToFunc()
}
// ByQuota orders the results by the quota field.
func ByQuota(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldQuota, opts...).ToFunc()

View File

@@ -95,6 +95,11 @@ func Status(v string) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldStatus, v))
}
// LastUsedAt applies equality check predicate on the "last_used_at" field. It's identical to LastUsedAtEQ.
func LastUsedAt(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldLastUsedAt, v))
}
// Quota applies equality check predicate on the "quota" field. It's identical to QuotaEQ.
func Quota(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldQuota, v))
@@ -485,6 +490,56 @@ func StatusContainsFold(v string) predicate.APIKey {
return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v))
}
// LastUsedAtEQ applies the EQ predicate on the "last_used_at" field.
func LastUsedAtEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldLastUsedAt, v))
}
// LastUsedAtNEQ applies the NEQ predicate on the "last_used_at" field.
func LastUsedAtNEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldLastUsedAt, v))
}
// LastUsedAtIn applies the In predicate on the "last_used_at" field.
func LastUsedAtIn(vs ...time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldLastUsedAt, vs...))
}
// LastUsedAtNotIn applies the NotIn predicate on the "last_used_at" field.
func LastUsedAtNotIn(vs ...time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldLastUsedAt, vs...))
}
// LastUsedAtGT applies the GT predicate on the "last_used_at" field.
func LastUsedAtGT(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldLastUsedAt, v))
}
// LastUsedAtGTE applies the GTE predicate on the "last_used_at" field.
func LastUsedAtGTE(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldLastUsedAt, v))
}
// LastUsedAtLT applies the LT predicate on the "last_used_at" field.
func LastUsedAtLT(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldLastUsedAt, v))
}
// LastUsedAtLTE applies the LTE predicate on the "last_used_at" field.
func LastUsedAtLTE(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldLastUsedAt, v))
}
// LastUsedAtIsNil applies the IsNil predicate on the "last_used_at" field.
func LastUsedAtIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldLastUsedAt))
}
// LastUsedAtNotNil applies the NotNil predicate on the "last_used_at" field.
func LastUsedAtNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldLastUsedAt))
}
// IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field.
func IPWhitelistIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist))

View File

@@ -113,6 +113,20 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate {
return _c
}
// SetLastUsedAt sets the "last_used_at" field.
func (_c *APIKeyCreate) SetLastUsedAt(v time.Time) *APIKeyCreate {
_c.mutation.SetLastUsedAt(v)
return _c
}
// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableLastUsedAt(v *time.Time) *APIKeyCreate {
if v != nil {
_c.SetLastUsedAt(*v)
}
return _c
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate {
_c.mutation.SetIPWhitelist(v)
@@ -353,6 +367,10 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) {
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
_node.Status = value
}
if value, ok := _c.mutation.LastUsedAt(); ok {
_spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value)
_node.LastUsedAt = &value
}
if value, ok := _c.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
_node.IPWhitelist = value
@@ -571,6 +589,24 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert {
return u
}
// SetLastUsedAt sets the "last_used_at" field.
func (u *APIKeyUpsert) SetLastUsedAt(v time.Time) *APIKeyUpsert {
u.Set(apikey.FieldLastUsedAt, v)
return u
}
// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateLastUsedAt() *APIKeyUpsert {
u.SetExcluded(apikey.FieldLastUsedAt)
return u
}
// ClearLastUsedAt clears the value of the "last_used_at" field.
func (u *APIKeyUpsert) ClearLastUsedAt() *APIKeyUpsert {
u.SetNull(apikey.FieldLastUsedAt)
return u
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert {
u.Set(apikey.FieldIPWhitelist, v)
@@ -818,6 +854,27 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne {
})
}
// SetLastUsedAt sets the "last_used_at" field.
func (u *APIKeyUpsertOne) SetLastUsedAt(v time.Time) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetLastUsedAt(v)
})
}
// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateLastUsedAt() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateLastUsedAt()
})
}
// ClearLastUsedAt clears the value of the "last_used_at" field.
func (u *APIKeyUpsertOne) ClearLastUsedAt() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearLastUsedAt()
})
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
@@ -1246,6 +1303,27 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk {
})
}
// SetLastUsedAt sets the "last_used_at" field.
func (u *APIKeyUpsertBulk) SetLastUsedAt(v time.Time) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetLastUsedAt(v)
})
}
// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateLastUsedAt() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateLastUsedAt()
})
}
// ClearLastUsedAt clears the value of the "last_used_at" field.
func (u *APIKeyUpsertBulk) ClearLastUsedAt() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearLastUsedAt()
})
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {

View File

@@ -134,6 +134,26 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate {
return _u
}
// SetLastUsedAt sets the "last_used_at" field.
func (_u *APIKeyUpdate) SetLastUsedAt(v time.Time) *APIKeyUpdate {
_u.mutation.SetLastUsedAt(v)
return _u
}
// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableLastUsedAt(v *time.Time) *APIKeyUpdate {
if v != nil {
_u.SetLastUsedAt(*v)
}
return _u
}
// ClearLastUsedAt clears the value of the "last_used_at" field.
func (_u *APIKeyUpdate) ClearLastUsedAt() *APIKeyUpdate {
_u.mutation.ClearLastUsedAt()
return _u
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (_u *APIKeyUpdate) SetIPWhitelist(v []string) *APIKeyUpdate {
_u.mutation.SetIPWhitelist(v)
@@ -390,6 +410,12 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.LastUsedAt(); ok {
_spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value)
}
if _u.mutation.LastUsedAtCleared() {
_spec.ClearField(apikey.FieldLastUsedAt, field.TypeTime)
}
if value, ok := _u.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
}
@@ -655,6 +681,26 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne {
return _u
}
// SetLastUsedAt sets the "last_used_at" field.
func (_u *APIKeyUpdateOne) SetLastUsedAt(v time.Time) *APIKeyUpdateOne {
_u.mutation.SetLastUsedAt(v)
return _u
}
// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableLastUsedAt(v *time.Time) *APIKeyUpdateOne {
if v != nil {
_u.SetLastUsedAt(*v)
}
return _u
}
// ClearLastUsedAt clears the value of the "last_used_at" field.
func (_u *APIKeyUpdateOne) ClearLastUsedAt() *APIKeyUpdateOne {
_u.mutation.ClearLastUsedAt()
return _u
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (_u *APIKeyUpdateOne) SetIPWhitelist(v []string) *APIKeyUpdateOne {
_u.mutation.SetIPWhitelist(v)
@@ -941,6 +987,12 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.LastUsedAt(); ok {
_spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value)
}
if _u.mutation.LastUsedAtCleared() {
_spec.ClearField(apikey.FieldLastUsedAt, field.TypeTime)
}
if value, ok := _u.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
}

View File

@@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -65,6 +66,8 @@ type Client struct {
Proxy *ProxyClient
// RedeemCode is the client for interacting with the RedeemCode builders.
RedeemCode *RedeemCodeClient
// SecuritySecret is the client for interacting with the SecuritySecret builders.
SecuritySecret *SecuritySecretClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
@@ -103,6 +106,7 @@ func (c *Client) init() {
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
c.Proxy = NewProxyClient(c.config)
c.RedeemCode = NewRedeemCodeClient(c.config)
c.SecuritySecret = NewSecuritySecretClient(c.config)
c.Setting = NewSettingClient(c.config)
c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config)
c.UsageLog = NewUsageLogClient(c.config)
@@ -214,6 +218,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
SecuritySecret: NewSecuritySecretClient(cfg),
Setting: NewSettingClient(cfg),
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
@@ -252,6 +257,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
SecuritySecret: NewSecuritySecretClient(cfg),
Setting: NewSettingClient(cfg),
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
@@ -291,8 +297,8 @@ func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy,
c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog,
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Use(hooks...)
@@ -305,8 +311,8 @@ func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy,
c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog,
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Intercept(interceptors...)
@@ -338,6 +344,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.Proxy.mutate(ctx, m)
case *RedeemCodeMutation:
return c.RedeemCode.mutate(ctx, m)
case *SecuritySecretMutation:
return c.SecuritySecret.mutate(ctx, m)
case *SettingMutation:
return c.Setting.mutate(ctx, m)
case *UsageCleanupTaskMutation:
@@ -2197,6 +2205,139 @@ func (c *RedeemCodeClient) mutate(ctx context.Context, m *RedeemCodeMutation) (V
}
}
// SecuritySecretClient is a client for the SecuritySecret schema.
type SecuritySecretClient struct {
config
}
// NewSecuritySecretClient returns a client for the SecuritySecret from the given config.
func NewSecuritySecretClient(c config) *SecuritySecretClient {
return &SecuritySecretClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `securitysecret.Hooks(f(g(h())))`.
func (c *SecuritySecretClient) Use(hooks ...Hook) {
c.hooks.SecuritySecret = append(c.hooks.SecuritySecret, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `securitysecret.Intercept(f(g(h())))`.
func (c *SecuritySecretClient) Intercept(interceptors ...Interceptor) {
c.inters.SecuritySecret = append(c.inters.SecuritySecret, interceptors...)
}
// Create returns a builder for creating a SecuritySecret entity.
func (c *SecuritySecretClient) Create() *SecuritySecretCreate {
mutation := newSecuritySecretMutation(c.config, OpCreate)
return &SecuritySecretCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of SecuritySecret entities.
func (c *SecuritySecretClient) CreateBulk(builders ...*SecuritySecretCreate) *SecuritySecretCreateBulk {
return &SecuritySecretCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *SecuritySecretClient) MapCreateBulk(slice any, setFunc func(*SecuritySecretCreate, int)) *SecuritySecretCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &SecuritySecretCreateBulk{err: fmt.Errorf("calling to SecuritySecretClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*SecuritySecretCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &SecuritySecretCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for SecuritySecret.
func (c *SecuritySecretClient) Update() *SecuritySecretUpdate {
mutation := newSecuritySecretMutation(c.config, OpUpdate)
return &SecuritySecretUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *SecuritySecretClient) UpdateOne(_m *SecuritySecret) *SecuritySecretUpdateOne {
mutation := newSecuritySecretMutation(c.config, OpUpdateOne, withSecuritySecret(_m))
return &SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *SecuritySecretClient) UpdateOneID(id int64) *SecuritySecretUpdateOne {
mutation := newSecuritySecretMutation(c.config, OpUpdateOne, withSecuritySecretID(id))
return &SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for SecuritySecret.
func (c *SecuritySecretClient) Delete() *SecuritySecretDelete {
mutation := newSecuritySecretMutation(c.config, OpDelete)
return &SecuritySecretDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *SecuritySecretClient) DeleteOne(_m *SecuritySecret) *SecuritySecretDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *SecuritySecretClient) DeleteOneID(id int64) *SecuritySecretDeleteOne {
builder := c.Delete().Where(securitysecret.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &SecuritySecretDeleteOne{builder}
}
// Query returns a query builder for SecuritySecret.
func (c *SecuritySecretClient) Query() *SecuritySecretQuery {
return &SecuritySecretQuery{
config: c.config,
ctx: &QueryContext{Type: TypeSecuritySecret},
inters: c.Interceptors(),
}
}
// Get returns a SecuritySecret entity by its id.
func (c *SecuritySecretClient) Get(ctx context.Context, id int64) (*SecuritySecret, error) {
return c.Query().Where(securitysecret.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *SecuritySecretClient) GetX(ctx context.Context, id int64) *SecuritySecret {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// Hooks returns the client hooks.
func (c *SecuritySecretClient) Hooks() []Hook {
return c.hooks.SecuritySecret
}
// Interceptors returns the client interceptors.
func (c *SecuritySecretClient) Interceptors() []Interceptor {
return c.inters.SecuritySecret
}
func (c *SecuritySecretClient) mutate(ctx context.Context, m *SecuritySecretMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&SecuritySecretCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&SecuritySecretUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&SecuritySecretDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown SecuritySecret mutation op: %q", m.Op())
}
}
// SettingClient is a client for the Setting schema.
type SettingClient struct {
config
@@ -3607,13 +3748,13 @@ type (
hooks struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
}
)

View File

@@ -23,6 +23,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -102,6 +103,7 @@ func checkColumn(t, c string) error {
promocodeusage.Table: promocodeusage.ValidColumn,
proxy.Table: proxy.ValidColumn,
redeemcode.Table: redeemcode.ValidColumn,
securitysecret.Table: securitysecret.ValidColumn,
setting.Table: setting.ValidColumn,
usagecleanuptask.Table: usagecleanuptask.ValidColumn,
usagelog.Table: usagelog.ValidColumn,

View File

@@ -52,6 +52,14 @@ type Group struct {
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
// ImagePrice4k holds the value of the "image_price_4k" field.
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
// SoraImagePrice360 holds the value of the "sora_image_price_360" field.
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
// SoraImagePrice540 holds the value of the "sora_image_price_540" field.
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
// SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field.
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
// SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field.
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
// 是否仅允许 Claude Code 客户端
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID
@@ -178,7 +186,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
values[i] = new(sql.NullInt64)
@@ -317,6 +325,34 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.ImagePrice4k = new(float64)
*_m.ImagePrice4k = value.Float64
}
case group.FieldSoraImagePrice360:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i])
} else if value.Valid {
_m.SoraImagePrice360 = new(float64)
*_m.SoraImagePrice360 = value.Float64
}
case group.FieldSoraImagePrice540:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i])
} else if value.Valid {
_m.SoraImagePrice540 = new(float64)
*_m.SoraImagePrice540 = value.Float64
}
case group.FieldSoraVideoPricePerRequest:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i])
} else if value.Valid {
_m.SoraVideoPricePerRequest = new(float64)
*_m.SoraVideoPricePerRequest = value.Float64
}
case group.FieldSoraVideoPricePerRequestHd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i])
} else if value.Valid {
_m.SoraVideoPricePerRequestHd = new(float64)
*_m.SoraVideoPricePerRequestHd = value.Float64
}
case group.FieldClaudeCodeOnly:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
@@ -514,6 +550,26 @@ func (_m *Group) String() string {
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.SoraImagePrice360; v != nil {
builder.WriteString("sora_image_price_360=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.SoraImagePrice540; v != nil {
builder.WriteString("sora_image_price_540=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.SoraVideoPricePerRequest; v != nil {
builder.WriteString("sora_video_price_per_request=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.SoraVideoPricePerRequestHd; v != nil {
builder.WriteString("sora_video_price_per_request_hd=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("claude_code_only=")
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
builder.WriteString(", ")

View File

@@ -49,6 +49,14 @@ const (
FieldImagePrice2k = "image_price_2k"
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
FieldImagePrice4k = "image_price_4k"
// FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database.
FieldSoraImagePrice360 = "sora_image_price_360"
// FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database.
FieldSoraImagePrice540 = "sora_image_price_540"
// FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database.
FieldSoraVideoPricePerRequest = "sora_video_price_per_request"
// FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database.
FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd"
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
@@ -157,6 +165,10 @@ var Columns = []string{
FieldImagePrice1k,
FieldImagePrice2k,
FieldImagePrice4k,
FieldSoraImagePrice360,
FieldSoraImagePrice540,
FieldSoraVideoPricePerRequest,
FieldSoraVideoPricePerRequestHd,
FieldClaudeCodeOnly,
FieldFallbackGroupID,
FieldFallbackGroupIDOnInvalidRequest,
@@ -325,6 +337,26 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
}
// BySoraImagePrice360 orders the results by the sora_image_price_360 field.
func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc()
}
// BySoraImagePrice540 orders the results by the sora_image_price_540 field.
func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc()
}
// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field.
func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc()
}
// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field.
func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc()
}
// ByClaudeCodeOnly orders the results by the claude_code_only field.
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()

View File

@@ -140,6 +140,26 @@ func ImagePrice4k(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
}
// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ.
func SoraImagePrice360(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
}
// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ.
func SoraImagePrice540(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
}
// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ.
func SoraVideoPricePerRequest(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ.
func SoraVideoPricePerRequestHd(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
}
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
func ClaudeCodeOnly(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
@@ -1025,6 +1045,206 @@ func ImagePrice4kNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
}
// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field.
func SoraImagePrice360EQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
}
// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field.
func SoraImagePrice360NEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v))
}
// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field.
func SoraImagePrice360In(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...))
}
// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field.
func SoraImagePrice360NotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...))
}
// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field.
func SoraImagePrice360GT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v))
}
// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field.
func SoraImagePrice360GTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v))
}
// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field.
func SoraImagePrice360LT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v))
}
// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field.
func SoraImagePrice360LTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v))
}
// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field.
func SoraImagePrice360IsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360))
}
// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field.
func SoraImagePrice360NotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360))
}
// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field.
func SoraImagePrice540EQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
}
// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field.
func SoraImagePrice540NEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v))
}
// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field.
func SoraImagePrice540In(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...))
}
// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field.
func SoraImagePrice540NotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...))
}
// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field.
func SoraImagePrice540GT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v))
}
// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field.
func SoraImagePrice540GTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v))
}
// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field.
func SoraImagePrice540LT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v))
}
// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field.
func SoraImagePrice540LTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v))
}
// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field.
func SoraImagePrice540IsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540))
}
// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field.
func SoraImagePrice540NotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540))
}
// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestNEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...))
}
// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...))
}
// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestGT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestGTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestLT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestLTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest))
}
// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest))
}
// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...))
}
// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...))
}
// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdGT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdLT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd))
}
// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd))
}
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))

View File

@@ -258,6 +258,62 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
return _c
}
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate {
_c.mutation.SetSoraImagePrice360(v)
return _c
}
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate {
if v != nil {
_c.SetSoraImagePrice360(*v)
}
return _c
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate {
_c.mutation.SetSoraImagePrice540(v)
return _c
}
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate {
if v != nil {
_c.SetSoraImagePrice540(*v)
}
return _c
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate {
_c.mutation.SetSoraVideoPricePerRequest(v)
return _c
}
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate {
if v != nil {
_c.SetSoraVideoPricePerRequest(*v)
}
return _c
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate {
_c.mutation.SetSoraVideoPricePerRequestHd(v)
return _c
}
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate {
if v != nil {
_c.SetSoraVideoPricePerRequestHd(*v)
}
return _c
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
_c.mutation.SetClaudeCodeOnly(v)
@@ -701,6 +757,22 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
_node.ImagePrice4k = &value
}
if value, ok := _c.mutation.SoraImagePrice360(); ok {
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
_node.SoraImagePrice360 = &value
}
if value, ok := _c.mutation.SoraImagePrice540(); ok {
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
_node.SoraImagePrice540 = &value
}
if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
_node.SoraVideoPricePerRequest = &value
}
if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
_node.SoraVideoPricePerRequestHd = &value
}
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
_node.ClaudeCodeOnly = value
@@ -1177,6 +1249,102 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
return u
}
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert {
u.Set(group.FieldSoraImagePrice360, v)
return u
}
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert {
u.SetExcluded(group.FieldSoraImagePrice360)
return u
}
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert {
u.Add(group.FieldSoraImagePrice360, v)
return u
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert {
u.SetNull(group.FieldSoraImagePrice360)
return u
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert {
u.Set(group.FieldSoraImagePrice540, v)
return u
}
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert {
u.SetExcluded(group.FieldSoraImagePrice540)
return u
}
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert {
u.Add(group.FieldSoraImagePrice540, v)
return u
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert {
u.SetNull(group.FieldSoraImagePrice540)
return u
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert {
u.Set(group.FieldSoraVideoPricePerRequest, v)
return u
}
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert {
u.SetExcluded(group.FieldSoraVideoPricePerRequest)
return u
}
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert {
u.Add(group.FieldSoraVideoPricePerRequest, v)
return u
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert {
u.SetNull(group.FieldSoraVideoPricePerRequest)
return u
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
u.Set(group.FieldSoraVideoPricePerRequestHd, v)
return u
}
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert {
u.SetExcluded(group.FieldSoraVideoPricePerRequestHd)
return u
}
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
u.Add(group.FieldSoraVideoPricePerRequestHd, v)
return u
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert {
u.SetNull(group.FieldSoraVideoPricePerRequestHd)
return u
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
u.Set(group.FieldClaudeCodeOnly, v)
@@ -1690,6 +1858,118 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
})
}
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSoraImagePrice360(v)
})
}
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSoraImagePrice360(v)
})
}
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraImagePrice360()
})
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraImagePrice360()
})
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSoraImagePrice540(v)
})
}
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSoraImagePrice540(v)
})
}
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraImagePrice540()
})
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraImagePrice540()
})
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSoraVideoPricePerRequest(v)
})
}
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSoraVideoPricePerRequest(v)
})
}
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraVideoPricePerRequest()
})
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraVideoPricePerRequest()
})
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSoraVideoPricePerRequestHd(v)
})
}
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSoraVideoPricePerRequestHd(v)
})
}
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraVideoPricePerRequestHd()
})
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraVideoPricePerRequestHd()
})
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
@@ -2391,6 +2671,118 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
})
}
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSoraImagePrice360(v)
})
}
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSoraImagePrice360(v)
})
}
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraImagePrice360()
})
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraImagePrice360()
})
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSoraImagePrice540(v)
})
}
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSoraImagePrice540(v)
})
}
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraImagePrice540()
})
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraImagePrice540()
})
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSoraVideoPricePerRequest(v)
})
}
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSoraVideoPricePerRequest(v)
})
}
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraVideoPricePerRequest()
})
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraVideoPricePerRequest()
})
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSoraVideoPricePerRequestHd(v)
})
}
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSoraVideoPricePerRequestHd(v)
})
}
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraVideoPricePerRequestHd()
})
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraVideoPricePerRequestHd()
})
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {

View File

@@ -355,6 +355,114 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
return _u
}
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate {
_u.mutation.ResetSoraImagePrice360()
_u.mutation.SetSoraImagePrice360(v)
return _u
}
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate {
if v != nil {
_u.SetSoraImagePrice360(*v)
}
return _u
}
// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate {
_u.mutation.AddSoraImagePrice360(v)
return _u
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate {
_u.mutation.ClearSoraImagePrice360()
return _u
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate {
_u.mutation.ResetSoraImagePrice540()
_u.mutation.SetSoraImagePrice540(v)
return _u
}
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate {
if v != nil {
_u.SetSoraImagePrice540(*v)
}
return _u
}
// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate {
_u.mutation.AddSoraImagePrice540(v)
return _u
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate {
_u.mutation.ClearSoraImagePrice540()
return _u
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate {
_u.mutation.ResetSoraVideoPricePerRequest()
_u.mutation.SetSoraVideoPricePerRequest(v)
return _u
}
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate {
if v != nil {
_u.SetSoraVideoPricePerRequest(*v)
}
return _u
}
// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate {
_u.mutation.AddSoraVideoPricePerRequest(v)
return _u
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate {
_u.mutation.ClearSoraVideoPricePerRequest()
return _u
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
_u.mutation.ResetSoraVideoPricePerRequestHd()
_u.mutation.SetSoraVideoPricePerRequestHd(v)
return _u
}
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate {
if v != nil {
_u.SetSoraVideoPricePerRequestHd(*v)
}
return _u
}
// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
_u.mutation.AddSoraVideoPricePerRequestHd(v)
return _u
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate {
_u.mutation.ClearSoraVideoPricePerRequestHd()
return _u
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
_u.mutation.SetClaudeCodeOnly(v)
@@ -892,6 +1000,42 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraImagePrice360(); ok {
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
_spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
}
if _u.mutation.SoraImagePrice360Cleared() {
_spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraImagePrice540(); ok {
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
_spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
}
if _u.mutation.SoraImagePrice540Cleared() {
_spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
_spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
}
if _u.mutation.SoraVideoPricePerRequestCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
_spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
}
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
}
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
@@ -1573,6 +1717,114 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
return _u
}
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne {
_u.mutation.ResetSoraImagePrice360()
_u.mutation.SetSoraImagePrice360(v)
return _u
}
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetSoraImagePrice360(*v)
}
return _u
}
// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne {
_u.mutation.AddSoraImagePrice360(v)
return _u
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne {
_u.mutation.ClearSoraImagePrice360()
return _u
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne {
_u.mutation.ResetSoraImagePrice540()
_u.mutation.SetSoraImagePrice540(v)
return _u
}
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetSoraImagePrice540(*v)
}
return _u
}
// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne {
_u.mutation.AddSoraImagePrice540(v)
return _u
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne {
_u.mutation.ClearSoraImagePrice540()
return _u
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
_u.mutation.ResetSoraVideoPricePerRequest()
_u.mutation.SetSoraVideoPricePerRequest(v)
return _u
}
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetSoraVideoPricePerRequest(*v)
}
return _u
}
// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
_u.mutation.AddSoraVideoPricePerRequest(v)
return _u
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne {
_u.mutation.ClearSoraVideoPricePerRequest()
return _u
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
_u.mutation.ResetSoraVideoPricePerRequestHd()
_u.mutation.SetSoraVideoPricePerRequestHd(v)
return _u
}
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetSoraVideoPricePerRequestHd(*v)
}
return _u
}
// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
_u.mutation.AddSoraVideoPricePerRequestHd(v)
return _u
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne {
_u.mutation.ClearSoraVideoPricePerRequestHd()
return _u
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
_u.mutation.SetClaudeCodeOnly(v)
@@ -2140,6 +2392,42 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraImagePrice360(); ok {
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
_spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
}
if _u.mutation.SoraImagePrice360Cleared() {
_spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraImagePrice540(); ok {
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
_spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
}
if _u.mutation.SoraImagePrice540Cleared() {
_spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
_spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
}
if _u.mutation.SoraVideoPricePerRequestCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
_spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
}
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
}
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}

View File

@@ -141,6 +141,18 @@ func (f RedeemCodeFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value,
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.RedeemCodeMutation", m)
}
// The SecuritySecretFunc type is an adapter to allow the use of ordinary
// function as SecuritySecret mutator.
type SecuritySecretFunc func(context.Context, *ent.SecuritySecretMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f SecuritySecretFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.SecuritySecretMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SecuritySecretMutation", m)
}
// The SettingFunc type is an adapter to allow the use of ordinary
// function as Setting mutator.
type SettingFunc func(context.Context, *ent.SettingMutation) (ent.Value, error)

View File

@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -383,6 +384,33 @@ func (f TraverseRedeemCode) Traverse(ctx context.Context, q ent.Query) error {
return fmt.Errorf("unexpected query type %T. expect *ent.RedeemCodeQuery", q)
}
// The SecuritySecretFunc type is an adapter to allow the use of ordinary function as a Querier.
type SecuritySecretFunc func(context.Context, *ent.SecuritySecretQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f SecuritySecretFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.SecuritySecretQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.SecuritySecretQuery", q)
}
// The TraverseSecuritySecret type is an adapter to allow the use of ordinary function as Traverser.
type TraverseSecuritySecret func(context.Context, *ent.SecuritySecretQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseSecuritySecret) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseSecuritySecret) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.SecuritySecretQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.SecuritySecretQuery", q)
}
// The SettingFunc type is an adapter to allow the use of ordinary function as a Querier.
type SettingFunc func(context.Context, *ent.SettingQuery) (ent.Value, error)
@@ -624,6 +652,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.ProxyQuery, predicate.Proxy, proxy.OrderOption]{typ: ent.TypeProxy, tq: q}, nil
case *ent.RedeemCodeQuery:
return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil
case *ent.SecuritySecretQuery:
return &query[*ent.SecuritySecretQuery, predicate.SecuritySecret, securitysecret.OrderOption]{typ: ent.TypeSecuritySecret, tq: q}, nil
case *ent.SettingQuery:
return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil
case *ent.UsageCleanupTaskQuery:

View File

@@ -18,6 +18,7 @@ var (
{Name: "key", Type: field.TypeString, Unique: true, Size: 128},
{Name: "name", Type: field.TypeString, Size: 100},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "last_used_at", Type: field.TypeTime, Nullable: true},
{Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true},
{Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true},
{Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
@@ -34,13 +35,13 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "api_keys_groups_api_keys",
Columns: []*schema.Column{APIKeysColumns[12]},
Columns: []*schema.Column{APIKeysColumns[13]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "api_keys_users_api_keys",
Columns: []*schema.Column{APIKeysColumns[13]},
Columns: []*schema.Column{APIKeysColumns[14]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
@@ -49,12 +50,12 @@ var (
{
Name: "apikey_user_id",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[13]},
Columns: []*schema.Column{APIKeysColumns[14]},
},
{
Name: "apikey_group_id",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[12]},
Columns: []*schema.Column{APIKeysColumns[13]},
},
{
Name: "apikey_status",
@@ -66,15 +67,20 @@ var (
Unique: false,
Columns: []*schema.Column{APIKeysColumns[3]},
},
{
Name: "apikey_last_used_at",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[7]},
},
{
Name: "apikey_quota_quota_used",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[9], APIKeysColumns[10]},
Columns: []*schema.Column{APIKeysColumns[10], APIKeysColumns[11]},
},
{
Name: "apikey_expires_at",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[11]},
Columns: []*schema.Column{APIKeysColumns[12]},
},
},
}
@@ -366,6 +372,10 @@ var (
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
@@ -409,7 +419,7 @@ var (
{
Name: "group_sort_order",
Unique: false,
Columns: []*schema.Column{GroupsColumns[25]},
Columns: []*schema.Column{GroupsColumns[29]},
},
},
}
@@ -572,6 +582,20 @@ var (
},
},
}
// SecuritySecretsColumns holds the columns for the "security_secrets" table.
SecuritySecretsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "key", Type: field.TypeString, Unique: true, Size: 100},
{Name: "value", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
}
// SecuritySecretsTable holds the schema information for the "security_secrets" table.
SecuritySecretsTable = &schema.Table{
Name: "security_secrets",
Columns: SecuritySecretsColumns,
PrimaryKey: []*schema.Column{SecuritySecretsColumns[0]},
}
// SettingsColumns holds the columns for the "settings" table.
SettingsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -650,6 +674,8 @@ var (
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
{Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
{Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "api_key_id", Type: field.TypeInt64},
{Name: "account_id", Type: field.TypeInt64},
@@ -665,31 +691,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[31]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[32]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -698,32 +724,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[31]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[28]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[29]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[30]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[32]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[25]},
Columns: []*schema.Column{UsageLogsColumns[27]},
},
{
Name: "usagelog_model",
@@ -738,12 +764,12 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]},
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]},
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
},
},
}
@@ -999,6 +1025,7 @@ var (
PromoCodeUsagesTable,
ProxiesTable,
RedeemCodesTable,
SecuritySecretsTable,
SettingsTable,
UsageCleanupTasksTable,
UsageLogsTable,
@@ -1055,6 +1082,9 @@ func init() {
RedeemCodesTable.Annotation = &entsql.Annotation{
Table: "redeem_codes",
}
SecuritySecretsTable.Annotation = &entsql.Annotation{
Table: "security_secrets",
}
SettingsTable.Annotation = &entsql.Annotation{
Table: "settings",
}

File diff suppressed because it is too large Load Diff

View File

@@ -39,6 +39,9 @@ type Proxy func(*sql.Selector)
// RedeemCode is the predicate function for redeemcode builders.
type RedeemCode func(*sql.Selector)
// SecuritySecret is the predicate function for securitysecret builders.
type SecuritySecret func(*sql.Selector)
// Setting is the predicate function for setting builders.
type Setting func(*sql.Selector)

View File

@@ -17,6 +17,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/schema"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -93,11 +94,11 @@ func init() {
// apikey.StatusValidator is a validator for the "status" field. It is called by the builders before save.
apikey.StatusValidator = apikeyDescStatus.Validators[0].(func(string) error)
// apikeyDescQuota is the schema descriptor for quota field.
apikeyDescQuota := apikeyFields[7].Descriptor()
apikeyDescQuota := apikeyFields[8].Descriptor()
// apikey.DefaultQuota holds the default value on creation for the quota field.
apikey.DefaultQuota = apikeyDescQuota.Default.(float64)
// apikeyDescQuotaUsed is the schema descriptor for quota_used field.
apikeyDescQuotaUsed := apikeyFields[8].Descriptor()
apikeyDescQuotaUsed := apikeyFields[9].Descriptor()
// apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field.
apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64)
accountMixin := schema.Account{}.Mixin()
@@ -398,23 +399,23 @@ func init() {
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
groupDescClaudeCodeOnly := groupFields[18].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
groupDescModelRoutingEnabled := groupFields[18].Descriptor()
groupDescModelRoutingEnabled := groupFields[22].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
groupDescMcpXMLInject := groupFields[19].Descriptor()
groupDescMcpXMLInject := groupFields[23].Descriptor()
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
groupDescSupportedModelScopes := groupFields[20].Descriptor()
groupDescSupportedModelScopes := groupFields[24].Descriptor()
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
// groupDescSortOrder is the schema descriptor for sort_order field.
groupDescSortOrder := groupFields[21].Descriptor()
groupDescSortOrder := groupFields[25].Descriptor()
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
promocodeFields := schema.PromoCode{}.Fields()
@@ -602,6 +603,43 @@ func init() {
redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor()
// redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field.
redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int)
securitysecretMixin := schema.SecuritySecret{}.Mixin()
securitysecretMixinFields0 := securitysecretMixin[0].Fields()
_ = securitysecretMixinFields0
securitysecretFields := schema.SecuritySecret{}.Fields()
_ = securitysecretFields
// securitysecretDescCreatedAt is the schema descriptor for created_at field.
securitysecretDescCreatedAt := securitysecretMixinFields0[0].Descriptor()
// securitysecret.DefaultCreatedAt holds the default value on creation for the created_at field.
securitysecret.DefaultCreatedAt = securitysecretDescCreatedAt.Default.(func() time.Time)
// securitysecretDescUpdatedAt is the schema descriptor for updated_at field.
securitysecretDescUpdatedAt := securitysecretMixinFields0[1].Descriptor()
// securitysecret.DefaultUpdatedAt holds the default value on creation for the updated_at field.
securitysecret.DefaultUpdatedAt = securitysecretDescUpdatedAt.Default.(func() time.Time)
// securitysecret.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
securitysecret.UpdateDefaultUpdatedAt = securitysecretDescUpdatedAt.UpdateDefault.(func() time.Time)
// securitysecretDescKey is the schema descriptor for key field.
securitysecretDescKey := securitysecretFields[0].Descriptor()
// securitysecret.KeyValidator is a validator for the "key" field. It is called by the builders before save.
securitysecret.KeyValidator = func() func(string) error {
validators := securitysecretDescKey.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(key string) error {
for _, fn := range fns {
if err := fn(key); err != nil {
return err
}
}
return nil
}
}()
// securitysecretDescValue is the schema descriptor for value field.
securitysecretDescValue := securitysecretFields[1].Descriptor()
// securitysecret.ValueValidator is a validator for the "value" field. It is called by the builders before save.
securitysecret.ValueValidator = securitysecretDescValue.Validators[0].(func(string) error)
settingFields := schema.Setting{}.Fields()
_ = settingFields
// settingDescKey is the schema descriptor for key field.
@@ -779,8 +817,16 @@ func init() {
usagelogDescImageSize := usagelogFields[28].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescMediaType is the schema descriptor for media_type field.
usagelogDescMediaType := usagelogFields[29].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[29].Descriptor()
usagelogDescCreatedAt := usagelogFields[31].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()

View File

@@ -47,6 +47,10 @@ func (APIKey) Fields() []ent.Field {
field.String("status").
MaxLen(20).
Default(domain.StatusActive),
field.Time("last_used_at").
Optional().
Nillable().
Comment("Last usage time of this API key"),
field.JSON("ip_whitelist", []string{}).
Optional().
Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"),
@@ -95,6 +99,7 @@ func (APIKey) Indexes() []ent.Index {
index.Fields("group_id"),
index.Fields("status"),
index.Fields("deleted_at"),
index.Fields("last_used_at"),
// Index for quota queries
index.Fields("quota", "quota_used"),
index.Fields("expires_at"),

View File

@@ -87,6 +87,24 @@ func (Group) Fields() []ent.Field {
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
// Sora 按次计费配置(阶段 1
field.Float("sora_image_price_360").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
field.Float("sora_image_price_540").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
field.Float("sora_video_price_per_request").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
field.Float("sora_video_price_per_request_hd").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
// Claude Code 客户端限制 (added by migration 029)
field.Bool("claude_code_only").
Default(false).

View File

@@ -0,0 +1,50 @@
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// IdempotencyRecord 幂等请求记录表。
type IdempotencyRecord struct {
ent.Schema
}
func (IdempotencyRecord) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "idempotency_records"},
}
}
func (IdempotencyRecord) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (IdempotencyRecord) Fields() []ent.Field {
return []ent.Field{
field.String("scope").MaxLen(128),
field.String("idempotency_key_hash").MaxLen(64),
field.String("request_fingerprint").MaxLen(64),
field.String("status").MaxLen(32),
field.Int("response_status").Optional().Nillable(),
field.String("response_body").Optional().Nillable(),
field.String("error_reason").MaxLen(128).Optional().Nillable(),
field.Time("locked_until").Optional().Nillable(),
field.Time("expires_at"),
}
}
func (IdempotencyRecord) Indexes() []ent.Index {
return []ent.Index{
index.Fields("scope", "idempotency_key_hash").Unique(),
index.Fields("expires_at"),
index.Fields("status", "locked_until"),
}
}

View File

@@ -0,0 +1,42 @@
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
)
// SecuritySecret 存储系统级安全密钥(如 JWT 签名密钥、TOTP 加密密钥)。
type SecuritySecret struct {
ent.Schema
}
func (SecuritySecret) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "security_secrets"},
}
}
func (SecuritySecret) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (SecuritySecret) Fields() []ent.Field {
return []ent.Field{
field.String("key").
MaxLen(100).
NotEmpty().
Unique(),
field.String("value").
NotEmpty().
SchemaType(map[string]string{
dialect.Postgres: "text",
}),
}
}

View File

@@ -118,6 +118,15 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(10).
Optional().
Nillable(),
// 媒体类型字段sora 使用)
field.String("media_type").
MaxLen(16).
Optional().
Nillable(),
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
field.Bool("cache_ttl_overridden").
Default(false),
// 时间戳(只有 created_at日志不可修改
field.Time("created_at").

View File

@@ -0,0 +1,139 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
)
// SecuritySecret is the model entity for the SecuritySecret schema.
type SecuritySecret struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// Key holds the value of the "key" field.
Key string `json:"key,omitempty"`
// Value holds the value of the "value" field.
Value string `json:"value,omitempty"`
selectValues sql.SelectValues
}
// scanValues returns the types for scanning values from sql.Rows.
func (*SecuritySecret) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case securitysecret.FieldID:
values[i] = new(sql.NullInt64)
case securitysecret.FieldKey, securitysecret.FieldValue:
values[i] = new(sql.NullString)
case securitysecret.FieldCreatedAt, securitysecret.FieldUpdatedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the SecuritySecret fields.
func (_m *SecuritySecret) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case securitysecret.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case securitysecret.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case securitysecret.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case securitysecret.FieldKey:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field key", values[i])
} else if value.Valid {
_m.Key = value.String
}
case securitysecret.FieldValue:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field value", values[i])
} else if value.Valid {
_m.Value = value.String
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// GetValue returns the ent.Value that was dynamically selected and assigned to the SecuritySecret.
// This includes values selected through modifiers, order, etc.
func (_m *SecuritySecret) GetValue(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// Update returns a builder for updating this SecuritySecret.
// Note that you need to call SecuritySecret.Unwrap() before calling this method if this SecuritySecret
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *SecuritySecret) Update() *SecuritySecretUpdateOne {
return NewSecuritySecretClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the SecuritySecret entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *SecuritySecret) Unwrap() *SecuritySecret {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: SecuritySecret is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *SecuritySecret) String() string {
var builder strings.Builder
builder.WriteString("SecuritySecret(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("key=")
builder.WriteString(_m.Key)
builder.WriteString(", ")
builder.WriteString("value=")
builder.WriteString(_m.Value)
builder.WriteByte(')')
return builder.String()
}
// SecuritySecrets is a parsable slice of SecuritySecret.
type SecuritySecrets []*SecuritySecret

View File

@@ -0,0 +1,86 @@
// Code generated by ent, DO NOT EDIT.
package securitysecret
import (
"time"
"entgo.io/ent/dialect/sql"
)
const (
// Label holds the string label denoting the securitysecret type in the database.
Label = "security_secret"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldKey holds the string denoting the key field in the database.
FieldKey = "key"
// FieldValue holds the string denoting the value field in the database.
FieldValue = "value"
// Table holds the table name of the securitysecret in the database.
Table = "security_secrets"
)
// Columns holds all SQL columns for securitysecret fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldKey,
FieldValue,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// KeyValidator is a validator for the "key" field. It is called by the builders before save.
KeyValidator func(string) error
// ValueValidator is a validator for the "value" field. It is called by the builders before save.
ValueValidator func(string) error
)
// OrderOption defines the ordering options for the SecuritySecret queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByKey orders the results by the key field.
func ByKey(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldKey, opts...).ToFunc()
}
// ByValue orders the results by the value field.
func ByValue(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldValue, opts...).ToFunc()
}

View File

@@ -0,0 +1,300 @@
// Code generated by ent, DO NOT EDIT.
package securitysecret
import (
"time"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldUpdatedAt, v))
}
// Key applies equality check predicate on the "key" field. It's identical to KeyEQ.
func Key(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldKey, v))
}
// Value applies equality check predicate on the "value" field. It's identical to ValueEQ.
func Value(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldValue, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLTE(FieldUpdatedAt, v))
}
// KeyEQ applies the EQ predicate on the "key" field.
func KeyEQ(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldKey, v))
}
// KeyNEQ applies the NEQ predicate on the "key" field.
func KeyNEQ(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNEQ(FieldKey, v))
}
// KeyIn applies the In predicate on the "key" field.
func KeyIn(vs ...string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldIn(FieldKey, vs...))
}
// KeyNotIn applies the NotIn predicate on the "key" field.
func KeyNotIn(vs ...string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNotIn(FieldKey, vs...))
}
// KeyGT applies the GT predicate on the "key" field.
func KeyGT(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGT(FieldKey, v))
}
// KeyGTE applies the GTE predicate on the "key" field.
func KeyGTE(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGTE(FieldKey, v))
}
// KeyLT applies the LT predicate on the "key" field.
func KeyLT(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLT(FieldKey, v))
}
// KeyLTE applies the LTE predicate on the "key" field.
func KeyLTE(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLTE(FieldKey, v))
}
// KeyContains applies the Contains predicate on the "key" field.
func KeyContains(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldContains(FieldKey, v))
}
// KeyHasPrefix applies the HasPrefix predicate on the "key" field.
func KeyHasPrefix(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldHasPrefix(FieldKey, v))
}
// KeyHasSuffix applies the HasSuffix predicate on the "key" field.
func KeyHasSuffix(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldHasSuffix(FieldKey, v))
}
// KeyEqualFold applies the EqualFold predicate on the "key" field.
func KeyEqualFold(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEqualFold(FieldKey, v))
}
// KeyContainsFold applies the ContainsFold predicate on the "key" field.
func KeyContainsFold(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldContainsFold(FieldKey, v))
}
// ValueEQ applies the EQ predicate on the "value" field.
func ValueEQ(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldValue, v))
}
// ValueNEQ applies the NEQ predicate on the "value" field.
func ValueNEQ(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNEQ(FieldValue, v))
}
// ValueIn applies the In predicate on the "value" field.
func ValueIn(vs ...string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldIn(FieldValue, vs...))
}
// ValueNotIn applies the NotIn predicate on the "value" field.
func ValueNotIn(vs ...string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNotIn(FieldValue, vs...))
}
// ValueGT applies the GT predicate on the "value" field.
func ValueGT(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGT(FieldValue, v))
}
// ValueGTE applies the GTE predicate on the "value" field.
func ValueGTE(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGTE(FieldValue, v))
}
// ValueLT applies the LT predicate on the "value" field.
func ValueLT(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLT(FieldValue, v))
}
// ValueLTE applies the LTE predicate on the "value" field.
func ValueLTE(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLTE(FieldValue, v))
}
// ValueContains applies the Contains predicate on the "value" field.
func ValueContains(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldContains(FieldValue, v))
}
// ValueHasPrefix applies the HasPrefix predicate on the "value" field.
func ValueHasPrefix(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldHasPrefix(FieldValue, v))
}
// ValueHasSuffix applies the HasSuffix predicate on the "value" field.
func ValueHasSuffix(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldHasSuffix(FieldValue, v))
}
// ValueEqualFold applies the EqualFold predicate on the "value" field.
func ValueEqualFold(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEqualFold(FieldValue, v))
}
// ValueContainsFold applies the ContainsFold predicate on the "value" field.
func ValueContainsFold(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldContainsFold(FieldValue, v))
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.SecuritySecret) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.SecuritySecret) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.SecuritySecret) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.NotPredicates(p))
}

View File

@@ -0,0 +1,626 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
)
// SecuritySecretCreate is the builder for creating a SecuritySecret entity.
type SecuritySecretCreate struct {
config
mutation *SecuritySecretMutation
hooks []Hook
conflict []sql.ConflictOption
}
// SetCreatedAt sets the "created_at" field.
func (_c *SecuritySecretCreate) SetCreatedAt(v time.Time) *SecuritySecretCreate {
_c.mutation.SetCreatedAt(v)
return _c
}
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
func (_c *SecuritySecretCreate) SetNillableCreatedAt(v *time.Time) *SecuritySecretCreate {
if v != nil {
_c.SetCreatedAt(*v)
}
return _c
}
// SetUpdatedAt sets the "updated_at" field.
func (_c *SecuritySecretCreate) SetUpdatedAt(v time.Time) *SecuritySecretCreate {
_c.mutation.SetUpdatedAt(v)
return _c
}
// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
func (_c *SecuritySecretCreate) SetNillableUpdatedAt(v *time.Time) *SecuritySecretCreate {
if v != nil {
_c.SetUpdatedAt(*v)
}
return _c
}
// SetKey sets the "key" field.
func (_c *SecuritySecretCreate) SetKey(v string) *SecuritySecretCreate {
_c.mutation.SetKey(v)
return _c
}
// SetValue sets the "value" field.
func (_c *SecuritySecretCreate) SetValue(v string) *SecuritySecretCreate {
_c.mutation.SetValue(v)
return _c
}
// Mutation returns the SecuritySecretMutation object of the builder.
func (_c *SecuritySecretCreate) Mutation() *SecuritySecretMutation {
return _c.mutation
}
// Save creates the SecuritySecret in the database.
func (_c *SecuritySecretCreate) Save(ctx context.Context) (*SecuritySecret, error) {
_c.defaults()
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func (_c *SecuritySecretCreate) SaveX(ctx context.Context) *SecuritySecret {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *SecuritySecretCreate) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *SecuritySecretCreate) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_c *SecuritySecretCreate) defaults() {
if _, ok := _c.mutation.CreatedAt(); !ok {
v := securitysecret.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
v := securitysecret.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_c *SecuritySecretCreate) check() error {
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "SecuritySecret.created_at"`)}
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "SecuritySecret.updated_at"`)}
}
if _, ok := _c.mutation.Key(); !ok {
return &ValidationError{Name: "key", err: errors.New(`ent: missing required field "SecuritySecret.key"`)}
}
if v, ok := _c.mutation.Key(); ok {
if err := securitysecret.KeyValidator(v); err != nil {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)}
}
}
if _, ok := _c.mutation.Value(); !ok {
return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "SecuritySecret.value"`)}
}
if v, ok := _c.mutation.Value(); ok {
if err := securitysecret.ValueValidator(v); err != nil {
return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)}
}
}
return nil
}
func (_c *SecuritySecretCreate) sqlSave(ctx context.Context) (*SecuritySecret, error) {
if err := _c.check(); err != nil {
return nil, err
}
_node, _spec := _c.createSpec()
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
id := _spec.ID.Value.(int64)
_node.ID = int64(id)
_c.mutation.id = &_node.ID
_c.mutation.done = true
return _node, nil
}
func (_c *SecuritySecretCreate) createSpec() (*SecuritySecret, *sqlgraph.CreateSpec) {
var (
_node = &SecuritySecret{config: _c.config}
_spec = sqlgraph.NewCreateSpec(securitysecret.Table, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
)
_spec.OnConflict = _c.conflict
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(securitysecret.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
if value, ok := _c.mutation.UpdatedAt(); ok {
_spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = value
}
if value, ok := _c.mutation.Key(); ok {
_spec.SetField(securitysecret.FieldKey, field.TypeString, value)
_node.Key = value
}
if value, ok := _c.mutation.Value(); ok {
_spec.SetField(securitysecret.FieldValue, field.TypeString, value)
_node.Value = value
}
return _node, _spec
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.SecuritySecret.Create().
// SetCreatedAt(v).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.SecuritySecretUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *SecuritySecretCreate) OnConflict(opts ...sql.ConflictOption) *SecuritySecretUpsertOne {
_c.conflict = opts
return &SecuritySecretUpsertOne{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.SecuritySecret.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *SecuritySecretCreate) OnConflictColumns(columns ...string) *SecuritySecretUpsertOne {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &SecuritySecretUpsertOne{
create: _c,
}
}
type (
// SecuritySecretUpsertOne is the builder for "upsert"-ing
// one SecuritySecret node.
SecuritySecretUpsertOne struct {
create *SecuritySecretCreate
}
// SecuritySecretUpsert is the "OnConflict" setter.
SecuritySecretUpsert struct {
*sql.UpdateSet
}
)
// SetUpdatedAt sets the "updated_at" field.
func (u *SecuritySecretUpsert) SetUpdatedAt(v time.Time) *SecuritySecretUpsert {
u.Set(securitysecret.FieldUpdatedAt, v)
return u
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *SecuritySecretUpsert) UpdateUpdatedAt() *SecuritySecretUpsert {
u.SetExcluded(securitysecret.FieldUpdatedAt)
return u
}
// SetKey sets the "key" field.
func (u *SecuritySecretUpsert) SetKey(v string) *SecuritySecretUpsert {
u.Set(securitysecret.FieldKey, v)
return u
}
// UpdateKey sets the "key" field to the value that was provided on create.
func (u *SecuritySecretUpsert) UpdateKey() *SecuritySecretUpsert {
u.SetExcluded(securitysecret.FieldKey)
return u
}
// SetValue sets the "value" field.
func (u *SecuritySecretUpsert) SetValue(v string) *SecuritySecretUpsert {
u.Set(securitysecret.FieldValue, v)
return u
}
// UpdateValue sets the "value" field to the value that was provided on create.
func (u *SecuritySecretUpsert) UpdateValue() *SecuritySecretUpsert {
u.SetExcluded(securitysecret.FieldValue)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
// client.SecuritySecret.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *SecuritySecretUpsertOne) UpdateNewValues() *SecuritySecretUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
if _, exists := u.create.mutation.CreatedAt(); exists {
s.SetIgnore(securitysecret.FieldCreatedAt)
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.SecuritySecret.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *SecuritySecretUpsertOne) Ignore() *SecuritySecretUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *SecuritySecretUpsertOne) DoNothing() *SecuritySecretUpsertOne {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the SecuritySecretCreate.OnConflict
// documentation for more info.
func (u *SecuritySecretUpsertOne) Update(set func(*SecuritySecretUpsert)) *SecuritySecretUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&SecuritySecretUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *SecuritySecretUpsertOne) SetUpdatedAt(v time.Time) *SecuritySecretUpsertOne {
return u.Update(func(s *SecuritySecretUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *SecuritySecretUpsertOne) UpdateUpdatedAt() *SecuritySecretUpsertOne {
return u.Update(func(s *SecuritySecretUpsert) {
s.UpdateUpdatedAt()
})
}
// SetKey sets the "key" field.
func (u *SecuritySecretUpsertOne) SetKey(v string) *SecuritySecretUpsertOne {
return u.Update(func(s *SecuritySecretUpsert) {
s.SetKey(v)
})
}
// UpdateKey sets the "key" field to the value that was provided on create.
func (u *SecuritySecretUpsertOne) UpdateKey() *SecuritySecretUpsertOne {
return u.Update(func(s *SecuritySecretUpsert) {
s.UpdateKey()
})
}
// SetValue sets the "value" field.
func (u *SecuritySecretUpsertOne) SetValue(v string) *SecuritySecretUpsertOne {
return u.Update(func(s *SecuritySecretUpsert) {
s.SetValue(v)
})
}
// UpdateValue sets the "value" field to the value that was provided on create.
func (u *SecuritySecretUpsertOne) UpdateValue() *SecuritySecretUpsertOne {
return u.Update(func(s *SecuritySecretUpsert) {
s.UpdateValue()
})
}
// Exec executes the query.
func (u *SecuritySecretUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for SecuritySecretCreate.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *SecuritySecretUpsertOne) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}
// Exec executes the UPSERT query and returns the inserted/updated ID.
func (u *SecuritySecretUpsertOne) ID(ctx context.Context) (id int64, err error) {
node, err := u.create.Save(ctx)
if err != nil {
return id, err
}
return node.ID, nil
}
// IDX is like ID, but panics if an error occurs.
func (u *SecuritySecretUpsertOne) IDX(ctx context.Context) int64 {
id, err := u.ID(ctx)
if err != nil {
panic(err)
}
return id
}
// SecuritySecretCreateBulk is the builder for creating many SecuritySecret entities in bulk.
type SecuritySecretCreateBulk struct {
config
err error
builders []*SecuritySecretCreate
conflict []sql.ConflictOption
}
// Save creates the SecuritySecret entities in the database.
func (_c *SecuritySecretCreateBulk) Save(ctx context.Context) ([]*SecuritySecret, error) {
if _c.err != nil {
return nil, _c.err
}
specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
nodes := make([]*SecuritySecret, len(_c.builders))
mutators := make([]Mutator, len(_c.builders))
for i := range _c.builders {
func(i int, root context.Context) {
builder := _c.builders[i]
builder.defaults()
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*SecuritySecretMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err := builder.check(); err != nil {
return nil, err
}
builder.mutation = mutation
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
spec.OnConflict = _c.conflict
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
}
}
if err != nil {
return nil, err
}
mutation.id = &nodes[i].ID
if specs[i].ID.Value != nil {
id := specs[i].ID.Value.(int64)
nodes[i].ID = int64(id)
}
mutation.done = true
return nodes[i], nil
})
for i := len(builder.hooks) - 1; i >= 0; i-- {
mut = builder.hooks[i](mut)
}
mutators[i] = mut
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
return nil, err
}
}
return nodes, nil
}
// SaveX is like Save, but panics if an error occurs.
func (_c *SecuritySecretCreateBulk) SaveX(ctx context.Context) []*SecuritySecret {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *SecuritySecretCreateBulk) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *SecuritySecretCreateBulk) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.SecuritySecret.CreateBulk(builders...).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.SecuritySecretUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *SecuritySecretCreateBulk) OnConflict(opts ...sql.ConflictOption) *SecuritySecretUpsertBulk {
_c.conflict = opts
return &SecuritySecretUpsertBulk{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.SecuritySecret.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *SecuritySecretCreateBulk) OnConflictColumns(columns ...string) *SecuritySecretUpsertBulk {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &SecuritySecretUpsertBulk{
create: _c,
}
}
// SecuritySecretUpsertBulk is the builder for "upsert"-ing
// a bulk of SecuritySecret nodes.
type SecuritySecretUpsertBulk struct {
create *SecuritySecretCreateBulk
}
// UpdateNewValues updates the mutable fields using the new values that
// were set on create. Using this option is equivalent to using:
//
// client.SecuritySecret.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *SecuritySecretUpsertBulk) UpdateNewValues() *SecuritySecretUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
for _, b := range u.create.builders {
if _, exists := b.mutation.CreatedAt(); exists {
s.SetIgnore(securitysecret.FieldCreatedAt)
}
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.SecuritySecret.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *SecuritySecretUpsertBulk) Ignore() *SecuritySecretUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *SecuritySecretUpsertBulk) DoNothing() *SecuritySecretUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the SecuritySecretCreateBulk.OnConflict
// documentation for more info.
func (u *SecuritySecretUpsertBulk) Update(set func(*SecuritySecretUpsert)) *SecuritySecretUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&SecuritySecretUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *SecuritySecretUpsertBulk) SetUpdatedAt(v time.Time) *SecuritySecretUpsertBulk {
return u.Update(func(s *SecuritySecretUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *SecuritySecretUpsertBulk) UpdateUpdatedAt() *SecuritySecretUpsertBulk {
return u.Update(func(s *SecuritySecretUpsert) {
s.UpdateUpdatedAt()
})
}
// SetKey sets the "key" field.
func (u *SecuritySecretUpsertBulk) SetKey(v string) *SecuritySecretUpsertBulk {
return u.Update(func(s *SecuritySecretUpsert) {
s.SetKey(v)
})
}
// UpdateKey sets the "key" field to the value that was provided on create.
func (u *SecuritySecretUpsertBulk) UpdateKey() *SecuritySecretUpsertBulk {
return u.Update(func(s *SecuritySecretUpsert) {
s.UpdateKey()
})
}
// SetValue sets the "value" field.
func (u *SecuritySecretUpsertBulk) SetValue(v string) *SecuritySecretUpsertBulk {
return u.Update(func(s *SecuritySecretUpsert) {
s.SetValue(v)
})
}
// UpdateValue sets the "value" field to the value that was provided on create.
func (u *SecuritySecretUpsertBulk) UpdateValue() *SecuritySecretUpsertBulk {
return u.Update(func(s *SecuritySecretUpsert) {
s.UpdateValue()
})
}
// Exec executes the query.
func (u *SecuritySecretUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
return u.create.err
}
for i, b := range u.create.builders {
if len(b.conflict) != 0 {
return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the SecuritySecretCreateBulk instead", i)
}
}
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for SecuritySecretCreateBulk.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *SecuritySecretUpsertBulk) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
)
// SecuritySecretDelete is the builder for deleting a SecuritySecret entity.
type SecuritySecretDelete struct {
config
hooks []Hook
mutation *SecuritySecretMutation
}
// Where appends a list predicates to the SecuritySecretDelete builder.
func (_d *SecuritySecretDelete) Where(ps ...predicate.SecuritySecret) *SecuritySecretDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *SecuritySecretDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *SecuritySecretDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *SecuritySecretDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(securitysecret.Table, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// SecuritySecretDeleteOne is the builder for deleting a single SecuritySecret entity.
type SecuritySecretDeleteOne struct {
_d *SecuritySecretDelete
}
// Where appends a list predicates to the SecuritySecretDelete builder.
func (_d *SecuritySecretDeleteOne) Where(ps ...predicate.SecuritySecret) *SecuritySecretDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *SecuritySecretDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{securitysecret.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *SecuritySecretDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,564 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
)
// SecuritySecretQuery is the builder for querying SecuritySecret entities.
type SecuritySecretQuery struct {
config
ctx *QueryContext
order []securitysecret.OrderOption
inters []Interceptor
predicates []predicate.SecuritySecret
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the SecuritySecretQuery builder.
func (_q *SecuritySecretQuery) Where(ps ...predicate.SecuritySecret) *SecuritySecretQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *SecuritySecretQuery) Limit(limit int) *SecuritySecretQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *SecuritySecretQuery) Offset(offset int) *SecuritySecretQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *SecuritySecretQuery) Unique(unique bool) *SecuritySecretQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *SecuritySecretQuery) Order(o ...securitysecret.OrderOption) *SecuritySecretQuery {
_q.order = append(_q.order, o...)
return _q
}
// First returns the first SecuritySecret entity from the query.
// Returns a *NotFoundError when no SecuritySecret was found.
func (_q *SecuritySecretQuery) First(ctx context.Context) (*SecuritySecret, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{securitysecret.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *SecuritySecretQuery) FirstX(ctx context.Context) *SecuritySecret {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first SecuritySecret ID from the query.
// Returns a *NotFoundError when no SecuritySecret ID was found.
func (_q *SecuritySecretQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{securitysecret.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *SecuritySecretQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single SecuritySecret entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one SecuritySecret entity is found.
// Returns a *NotFoundError when no SecuritySecret entities are found.
func (_q *SecuritySecretQuery) Only(ctx context.Context) (*SecuritySecret, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{securitysecret.Label}
default:
return nil, &NotSingularError{securitysecret.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *SecuritySecretQuery) OnlyX(ctx context.Context) *SecuritySecret {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only SecuritySecret ID in the query.
// Returns a *NotSingularError when more than one SecuritySecret ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *SecuritySecretQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{securitysecret.Label}
default:
err = &NotSingularError{securitysecret.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *SecuritySecretQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of SecuritySecrets.
func (_q *SecuritySecretQuery) All(ctx context.Context) ([]*SecuritySecret, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*SecuritySecret, *SecuritySecretQuery]()
return withInterceptors[[]*SecuritySecret](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *SecuritySecretQuery) AllX(ctx context.Context) []*SecuritySecret {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of SecuritySecret IDs.
func (_q *SecuritySecretQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(securitysecret.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *SecuritySecretQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *SecuritySecretQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*SecuritySecretQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *SecuritySecretQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *SecuritySecretQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *SecuritySecretQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the SecuritySecretQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *SecuritySecretQuery) Clone() *SecuritySecretQuery {
if _q == nil {
return nil
}
return &SecuritySecretQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]securitysecret.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.SecuritySecret{}, _q.predicates...),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.SecuritySecret.Query().
// GroupBy(securitysecret.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *SecuritySecretQuery) GroupBy(field string, fields ...string) *SecuritySecretGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &SecuritySecretGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = securitysecret.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.SecuritySecret.Query().
// Select(securitysecret.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *SecuritySecretQuery) Select(fields ...string) *SecuritySecretSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &SecuritySecretSelect{SecuritySecretQuery: _q}
sbuild.label = securitysecret.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a SecuritySecretSelect configured with the given aggregations.
func (_q *SecuritySecretQuery) Aggregate(fns ...AggregateFunc) *SecuritySecretSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *SecuritySecretQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !securitysecret.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *SecuritySecretQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*SecuritySecret, error) {
var (
nodes = []*SecuritySecret{}
_spec = _q.querySpec()
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*SecuritySecret).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &SecuritySecret{config: _q.config}
nodes = append(nodes, node)
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
return nodes, nil
}
func (_q *SecuritySecretQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *SecuritySecretQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, securitysecret.FieldID)
for i := range fields {
if fields[i] != securitysecret.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *SecuritySecretQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(securitysecret.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = securitysecret.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *SecuritySecretQuery) ForUpdate(opts ...sql.LockOption) *SecuritySecretQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *SecuritySecretQuery) ForShare(opts ...sql.LockOption) *SecuritySecretQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// SecuritySecretGroupBy is the group-by builder for SecuritySecret entities.
type SecuritySecretGroupBy struct {
selector
build *SecuritySecretQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *SecuritySecretGroupBy) Aggregate(fns ...AggregateFunc) *SecuritySecretGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *SecuritySecretGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*SecuritySecretQuery, *SecuritySecretGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *SecuritySecretGroupBy) sqlScan(ctx context.Context, root *SecuritySecretQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// SecuritySecretSelect is the builder for selecting fields of SecuritySecret entities.
type SecuritySecretSelect struct {
*SecuritySecretQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *SecuritySecretSelect) Aggregate(fns ...AggregateFunc) *SecuritySecretSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *SecuritySecretSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*SecuritySecretQuery, *SecuritySecretSelect](ctx, _s.SecuritySecretQuery, _s, _s.inters, v)
}
func (_s *SecuritySecretSelect) sqlScan(ctx context.Context, root *SecuritySecretQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,316 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
)
// SecuritySecretUpdate is the builder for updating SecuritySecret entities.
type SecuritySecretUpdate struct {
config
hooks []Hook
mutation *SecuritySecretMutation
}
// Where appends a list predicates to the SecuritySecretUpdate builder.
func (_u *SecuritySecretUpdate) Where(ps ...predicate.SecuritySecret) *SecuritySecretUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *SecuritySecretUpdate) SetUpdatedAt(v time.Time) *SecuritySecretUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetKey sets the "key" field.
func (_u *SecuritySecretUpdate) SetKey(v string) *SecuritySecretUpdate {
_u.mutation.SetKey(v)
return _u
}
// SetNillableKey sets the "key" field if the given value is not nil.
func (_u *SecuritySecretUpdate) SetNillableKey(v *string) *SecuritySecretUpdate {
if v != nil {
_u.SetKey(*v)
}
return _u
}
// SetValue sets the "value" field.
func (_u *SecuritySecretUpdate) SetValue(v string) *SecuritySecretUpdate {
_u.mutation.SetValue(v)
return _u
}
// SetNillableValue sets the "value" field if the given value is not nil.
func (_u *SecuritySecretUpdate) SetNillableValue(v *string) *SecuritySecretUpdate {
if v != nil {
_u.SetValue(*v)
}
return _u
}
// Mutation returns the SecuritySecretMutation object of the builder.
func (_u *SecuritySecretUpdate) Mutation() *SecuritySecretMutation {
return _u.mutation
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *SecuritySecretUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *SecuritySecretUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *SecuritySecretUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *SecuritySecretUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *SecuritySecretUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := securitysecret.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *SecuritySecretUpdate) check() error {
if v, ok := _u.mutation.Key(); ok {
if err := securitysecret.KeyValidator(v); err != nil {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)}
}
}
if v, ok := _u.mutation.Value(); ok {
if err := securitysecret.ValueValidator(v); err != nil {
return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)}
}
}
return nil
}
func (_u *SecuritySecretUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Key(); ok {
_spec.SetField(securitysecret.FieldKey, field.TypeString, value)
}
if value, ok := _u.mutation.Value(); ok {
_spec.SetField(securitysecret.FieldValue, field.TypeString, value)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{securitysecret.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// SecuritySecretUpdateOne is the builder for updating a single SecuritySecret entity.
type SecuritySecretUpdateOne struct {
config
fields []string
hooks []Hook
mutation *SecuritySecretMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *SecuritySecretUpdateOne) SetUpdatedAt(v time.Time) *SecuritySecretUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetKey sets the "key" field.
func (_u *SecuritySecretUpdateOne) SetKey(v string) *SecuritySecretUpdateOne {
_u.mutation.SetKey(v)
return _u
}
// SetNillableKey sets the "key" field if the given value is not nil.
func (_u *SecuritySecretUpdateOne) SetNillableKey(v *string) *SecuritySecretUpdateOne {
if v != nil {
_u.SetKey(*v)
}
return _u
}
// SetValue sets the "value" field.
func (_u *SecuritySecretUpdateOne) SetValue(v string) *SecuritySecretUpdateOne {
_u.mutation.SetValue(v)
return _u
}
// SetNillableValue sets the "value" field if the given value is not nil.
func (_u *SecuritySecretUpdateOne) SetNillableValue(v *string) *SecuritySecretUpdateOne {
if v != nil {
_u.SetValue(*v)
}
return _u
}
// Mutation returns the SecuritySecretMutation object of the builder.
func (_u *SecuritySecretUpdateOne) Mutation() *SecuritySecretMutation {
return _u.mutation
}
// Where appends a list predicates to the SecuritySecretUpdate builder.
func (_u *SecuritySecretUpdateOne) Where(ps ...predicate.SecuritySecret) *SecuritySecretUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *SecuritySecretUpdateOne) Select(field string, fields ...string) *SecuritySecretUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated SecuritySecret entity.
func (_u *SecuritySecretUpdateOne) Save(ctx context.Context) (*SecuritySecret, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *SecuritySecretUpdateOne) SaveX(ctx context.Context) *SecuritySecret {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *SecuritySecretUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *SecuritySecretUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *SecuritySecretUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := securitysecret.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *SecuritySecretUpdateOne) check() error {
if v, ok := _u.mutation.Key(); ok {
if err := securitysecret.KeyValidator(v); err != nil {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)}
}
}
if v, ok := _u.mutation.Value(); ok {
if err := securitysecret.ValueValidator(v); err != nil {
return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)}
}
}
return nil
}
func (_u *SecuritySecretUpdateOne) sqlSave(ctx context.Context) (_node *SecuritySecret, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SecuritySecret.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, securitysecret.FieldID)
for _, f := range fields {
if !securitysecret.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != securitysecret.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Key(); ok {
_spec.SetField(securitysecret.FieldKey, field.TypeString, value)
}
if value, ok := _u.mutation.Value(); ok {
_spec.SetField(securitysecret.FieldValue, field.TypeString, value)
}
_node = &SecuritySecret{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{securitysecret.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -36,6 +36,8 @@ type Tx struct {
Proxy *ProxyClient
// RedeemCode is the client for interacting with the RedeemCode builders.
RedeemCode *RedeemCodeClient
// SecuritySecret is the client for interacting with the SecuritySecret builders.
SecuritySecret *SecuritySecretClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
@@ -194,6 +196,7 @@ func (tx *Tx) init() {
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.SecuritySecret = NewSecuritySecretClient(tx.config)
tx.Setting = NewSettingClient(tx.config)
tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config)
tx.UsageLog = NewUsageLogClient(tx.config)

View File

@@ -80,6 +80,10 @@ type UsageLog struct {
ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field.
ImageSize *string `json:"image_size,omitempty"`
// MediaType holds the value of the "media_type" field.
MediaType *string `json:"media_type,omitempty"`
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
@@ -165,13 +169,13 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case usagelog.FieldStream:
case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden:
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@@ -378,6 +382,19 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.ImageSize = new(string)
*_m.ImageSize = value.String
}
case usagelog.FieldMediaType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field media_type", values[i])
} else if value.Valid {
_m.MediaType = new(string)
*_m.MediaType = value.String
}
case usagelog.FieldCacheTTLOverridden:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
} else if value.Valid {
_m.CacheTTLOverridden = value.Bool
}
case usagelog.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
@@ -548,6 +565,14 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.MediaType; v != nil {
builder.WriteString("media_type=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("cache_ttl_overridden=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
builder.WriteString(", ")
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteByte(')')

View File

@@ -72,6 +72,10 @@ const (
FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database.
FieldImageSize = "image_size"
// FieldMediaType holds the string denoting the media_type field in the database.
FieldMediaType = "media_type"
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
FieldCacheTTLOverridden = "cache_ttl_overridden"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// EdgeUser holds the string denoting the user edge name in mutations.
@@ -155,6 +159,8 @@ var Columns = []string{
FieldIPAddress,
FieldImageCount,
FieldImageSize,
FieldMediaType,
FieldCacheTTLOverridden,
FieldCreatedAt,
}
@@ -211,6 +217,10 @@ var (
DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
ImageSizeValidator func(string) error
// MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
MediaTypeValidator func(string) error
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
DefaultCacheTTLOverridden bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
)
@@ -368,6 +378,16 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
}
// ByMediaType orders the results by the media_type field.
func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMediaType, opts...).ToFunc()
}
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()

View File

@@ -200,6 +200,16 @@ func ImageSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
}
// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ.
func MediaType(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
}
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
func CacheTTLOverridden(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
@@ -1440,6 +1450,91 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
}
// MediaTypeEQ applies the EQ predicate on the "media_type" field.
func MediaTypeEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
}
// MediaTypeNEQ applies the NEQ predicate on the "media_type" field.
func MediaTypeNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v))
}
// MediaTypeIn applies the In predicate on the "media_type" field.
func MediaTypeIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...))
}
// MediaTypeNotIn applies the NotIn predicate on the "media_type" field.
func MediaTypeNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...))
}
// MediaTypeGT applies the GT predicate on the "media_type" field.
func MediaTypeGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldMediaType, v))
}
// MediaTypeGTE applies the GTE predicate on the "media_type" field.
func MediaTypeGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v))
}
// MediaTypeLT applies the LT predicate on the "media_type" field.
func MediaTypeLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldMediaType, v))
}
// MediaTypeLTE applies the LTE predicate on the "media_type" field.
func MediaTypeLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v))
}
// MediaTypeContains applies the Contains predicate on the "media_type" field.
func MediaTypeContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldMediaType, v))
}
// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field.
func MediaTypeHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v))
}
// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field.
func MediaTypeHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v))
}
// MediaTypeIsNil applies the IsNil predicate on the "media_type" field.
func MediaTypeIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldMediaType))
}
// MediaTypeNotNil applies the NotNil predicate on the "media_type" field.
func MediaTypeNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldMediaType))
}
// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field.
func MediaTypeEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v))
}
// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field.
func MediaTypeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
}
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
}
// CacheTTLOverriddenNEQ applies the NEQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenNEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldCacheTTLOverridden, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))

View File

@@ -393,6 +393,34 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
return _c
}
// SetMediaType sets the "media_type" field.
func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate {
_c.mutation.SetMediaType(v)
return _c
}
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
if v != nil {
_c.SetMediaType(*v)
}
return _c
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
_c.mutation.SetCacheTTLOverridden(v)
return _c
}
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableCacheTTLOverridden(v *bool) *UsageLogCreate {
if v != nil {
_c.SetCacheTTLOverridden(*v)
}
return _c
}
// SetCreatedAt sets the "created_at" field.
func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate {
_c.mutation.SetCreatedAt(v)
@@ -531,6 +559,10 @@ func (_c *UsageLogCreate) defaults() {
v := usagelog.DefaultImageCount
_c.mutation.SetImageCount(v)
}
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
v := usagelog.DefaultCacheTTLOverridden
_c.mutation.SetCacheTTLOverridden(v)
}
if _, ok := _c.mutation.CreatedAt(); !ok {
v := usagelog.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
@@ -627,6 +659,14 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if v, ok := _c.mutation.MediaType(); ok {
if err := usagelog.MediaTypeValidator(v); err != nil {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
}
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)}
}
@@ -762,6 +802,14 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
_node.ImageSize = &value
}
if value, ok := _c.mutation.MediaType(); ok {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
_node.MediaType = &value
}
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
_node.CacheTTLOverridden = value
}
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
@@ -1407,6 +1455,36 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
return u
}
// SetMediaType sets the "media_type" field.
func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert {
u.Set(usagelog.FieldMediaType, v)
return u
}
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldMediaType)
return u
}
// ClearMediaType clears the value of the "media_type" field.
func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
u.SetNull(usagelog.FieldMediaType)
return u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
u.Set(usagelog.FieldCacheTTLOverridden, v)
return u
}
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateCacheTTLOverridden() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldCacheTTLOverridden)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -2040,6 +2118,41 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
})
}
// SetMediaType sets the "media_type" field.
func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetMediaType(v)
})
}
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateMediaType()
})
}
// ClearMediaType clears the value of the "media_type" field.
func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearMediaType()
})
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetCacheTTLOverridden(v)
})
}
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateCacheTTLOverridden() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateCacheTTLOverridden()
})
}
// Exec executes the query.
func (u *UsageLogUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -2839,6 +2952,41 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
})
}
// SetMediaType sets the "media_type" field.
func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetMediaType(v)
})
}
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateMediaType()
})
}
// ClearMediaType clears the value of the "media_type" field.
func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearMediaType()
})
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetCacheTTLOverridden(v)
})
}
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateCacheTTLOverridden() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateCacheTTLOverridden()
})
}
// Exec executes the query.
func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -612,6 +612,40 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
return _u
}
// SetMediaType sets the "media_type" field.
func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate {
_u.mutation.SetMediaType(v)
return _u
}
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate {
if v != nil {
_u.SetMediaType(*v)
}
return _u
}
// ClearMediaType clears the value of the "media_type" field.
func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
_u.mutation.ClearMediaType()
return _u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
_u.mutation.SetCacheTTLOverridden(v)
return _u
}
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdate {
if v != nil {
_u.SetCacheTTLOverridden(*v)
}
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate {
return _u.SetUserID(v.ID)
@@ -726,6 +760,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if v, ok := _u.mutation.MediaType(); ok {
if err := usagelog.MediaTypeValidator(v); err != nil {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@@ -894,6 +933,15 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
if value, ok := _u.mutation.MediaType(); ok {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
}
if _u.mutation.MediaTypeCleared() {
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
}
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
@@ -1639,6 +1687,40 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
return _u
}
// SetMediaType sets the "media_type" field.
func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne {
_u.mutation.SetMediaType(v)
return _u
}
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetMediaType(*v)
}
return _u
}
// ClearMediaType clears the value of the "media_type" field.
func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
_u.mutation.ClearMediaType()
return _u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
_u.mutation.SetCacheTTLOverridden(v)
return _u
}
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdateOne {
if v != nil {
_u.SetCacheTTLOverridden(*v)
}
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne {
return _u.SetUserID(v.ID)
@@ -1766,6 +1848,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if v, ok := _u.mutation.MediaType(); ok {
if err := usagelog.MediaTypeValidator(v); err != nil {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@@ -1951,6 +2038,15 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
if value, ok := _u.mutation.MediaType(); ok {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
}
if _u.mutation.MediaTypeCleared() {
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
}
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,

View File

@@ -5,6 +5,9 @@ go 1.25.7
require (
entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/DouDOU-start/go-sora2api v1.1.0
github.com/alitto/pond/v2 v2.6.2
github.com/cespare/xxhash/v2 v2.3.0
github.com/dgraph-io/ristretto v0.2.0
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.2
@@ -13,9 +16,10 @@ require (
github.com/gorilla/websocket v1.5.3
github.com/imroc/req/v3 v3.57.0
github.com/lib/pq v1.10.9
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pquerna/otp v1.5.0
github.com/redis/go-redis/v9 v9.17.2
github.com/refraction-networking/utls v1.8.1
github.com/refraction-networking/utls v1.8.2
github.com/robfig/cron/v3 v3.0.1
github.com/shirou/gopsutil/v4 v4.25.6
github.com/spf13/viper v1.18.2
@@ -25,10 +29,12 @@ require (
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/zeromicro/go-zero v1.9.4
golang.org/x/crypto v0.47.0
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.48.0
golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0
golang.org/x/term v0.39.0
golang.org/x/term v0.40.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
)
@@ -41,11 +47,17 @@ require (
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/bdandy/go-errors v1.2.2 // indirect
github.com/bdandy/go-socks4 v1.2.3 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/bogdanfinn/fhttp v0.6.8 // indirect
github.com/bogdanfinn/quic-go-utls v1.0.9-utls // indirect
github.com/bogdanfinn/tls-client v1.14.0 // indirect
github.com/bogdanfinn/utls v1.7.7-barnius // indirect
github.com/bogdanfinn/websocket v1.5.5-barnius // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
@@ -103,7 +115,6 @@ require (
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
@@ -120,6 +131,7 @@ require (
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 // indirect
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
@@ -141,9 +153,9 @@ require (
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.33.0 // indirect
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect

View File

@@ -10,16 +10,36 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/DouDOU-start/go-sora2api v1.1.0 h1:PxWiukK77StiHxEngOFwT1rKUn9oTAJJTl07wQUXwiU=
github.com/DouDOU-start/go-sora2api v1.1.0/go.mod h1:dcwpethoKfAsMWskDD9iGgc/3yox2tkthPLSMVGnhkE=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw=
github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
github.com/bdandy/go-socks4 v1.2.3/go.mod h1:98kiVFgpdogR8aIGLWLvjDVZ8XcKPsSI/ypGrO+bqHI=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/bogdanfinn/fhttp v0.6.8 h1:LiQyHOY3i0QoxxNB7nq27/nGNNbtPj0fuBPozhR7Ws4=
github.com/bogdanfinn/fhttp v0.6.8/go.mod h1:A+EKDzMx2hb4IUbMx4TlkoHnaJEiLl8r/1Ss1Y+5e5M=
github.com/bogdanfinn/quic-go-utls v1.0.9-utls h1:tV6eDEiRbRCcepALSzxR94JUVD3N3ACIiRLgyc2Ep8s=
github.com/bogdanfinn/quic-go-utls v1.0.9-utls/go.mod h1:aHph9B9H9yPOt5xnhWKSOum27DJAqpiHzwX+gjvaXcg=
github.com/bogdanfinn/tls-client v1.14.0 h1:vyk7Cn4BIvLAGVuMfb0tP22OqogfO1lYamquQNEZU1A=
github.com/bogdanfinn/tls-client v1.14.0/go.mod h1:LsU6mXVn8MOFDwTkyRfI7V1BZM1p0wf2ZfZsICW/1fM=
github.com/bogdanfinn/utls v1.7.7-barnius h1:OuJ497cc7F3yKNVHRsYPQdGggmk5x6+V5ZlrCR7fOLU=
github.com/bogdanfinn/utls v1.7.7-barnius/go.mod h1:aAK1VZQlpKZClF1WEQeq6kyclbkPq4hz6xTbB5xSlmg=
github.com/bogdanfinn/websocket v1.5.5-barnius h1:bY+qnxpai1qe7Jmjx+Sds/cmOSpuuLoR8x61rWltjOI=
github.com/bogdanfinn/websocket v1.5.5-barnius/go.mod h1:gvvEw6pTKHb7yOiFvIfAFTStQWyrm25BMVCTj5wRSsI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
@@ -135,8 +155,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -172,8 +190,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -207,8 +223,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -234,12 +248,10 @@ github.com/quic-go/quic-go v0.57.1 h1:25KAAR9QR8KZrCZRThWMKVAwGoiHIrNbT72ULHTuI1
github.com/quic-go/quic-go v0.57.1/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -262,8 +274,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
@@ -285,6 +295,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5/go.mod h1:2JjD2zLQYH5HO74y5+aE3remJQvl6q4Sn6aWA2wD1Ng=
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
@@ -340,25 +352,32 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60=
go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -366,16 +385,19 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
@@ -391,6 +413,8 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -5,7 +5,7 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"log"
"log/slog"
"net/url"
"os"
"strings"
@@ -19,6 +19,13 @@ const (
RunModeSimple = "simple"
)
// 使用量记录队列溢出策略
const (
UsageRecordOverflowPolicyDrop = "drop"
UsageRecordOverflowPolicySample = "sample"
UsageRecordOverflowPolicySync = "sync"
)
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
@@ -38,31 +45,68 @@ const (
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"`
SubscriptionMaintenance SubscriptionMaintenanceConfig `mapstructure:"subscription_maintenance"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Sora SoraConfig `mapstructure:"sora"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Idempotency IdempotencyConfig `mapstructure:"idempotency"`
}
type LogConfig struct {
Level string `mapstructure:"level"`
Format string `mapstructure:"format"`
ServiceName string `mapstructure:"service_name"`
Environment string `mapstructure:"env"`
Caller bool `mapstructure:"caller"`
StacktraceLevel string `mapstructure:"stacktrace_level"`
Output LogOutputConfig `mapstructure:"output"`
Rotation LogRotationConfig `mapstructure:"rotation"`
Sampling LogSamplingConfig `mapstructure:"sampling"`
}
type LogOutputConfig struct {
ToStdout bool `mapstructure:"to_stdout"`
ToFile bool `mapstructure:"to_file"`
FilePath string `mapstructure:"file_path"`
}
type LogRotationConfig struct {
MaxSizeMB int `mapstructure:"max_size_mb"`
MaxBackups int `mapstructure:"max_backups"`
MaxAgeDays int `mapstructure:"max_age_days"`
Compress bool `mapstructure:"compress"`
LocalTime bool `mapstructure:"local_time"`
}
type LogSamplingConfig struct {
Enabled bool `mapstructure:"enabled"`
Initial int `mapstructure:"initial"`
Thereafter int `mapstructure:"thereafter"`
}
type GeminiConfig struct {
@@ -94,6 +138,25 @@ type UpdateConfig struct {
ProxyURL string `mapstructure:"proxy_url"`
}
type IdempotencyConfig struct {
// ObserveOnly 为 true 时处于观察期:未携带 Idempotency-Key 的请求继续放行。
ObserveOnly bool `mapstructure:"observe_only"`
// DefaultTTLSeconds 关键写接口的幂等记录默认 TTL
DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"`
// SystemOperationTTLSeconds 系统操作接口的幂等记录 TTL
SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"`
// ProcessingTimeoutSeconds processing 状态锁超时(秒)。
ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"`
// FailedRetryBackoffSeconds 失败退避窗口(秒)。
FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"`
// MaxStoredResponseLen 持久化响应体最大长度(字节)。
MaxStoredResponseLen int `mapstructure:"max_stored_response_len"`
// CleanupIntervalSeconds 过期记录清理周期(秒)。
CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"`
// CleanupBatchSize 每次清理的最大记录数。
CleanupBatchSize int `mapstructure:"cleanup_batch_size"`
}
type LinuxDoConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
@@ -126,6 +189,8 @@ type TokenRefreshConfig struct {
MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token默认关闭
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
}
type PricingConfig struct {
@@ -147,6 +212,7 @@ type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL用于生成邮件中的外部链接
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表CIDR/IP
@@ -173,6 +239,7 @@ type SecurityConfig struct {
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
CSP CSPConfig `mapstructure:"csp"`
ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"`
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
}
@@ -197,6 +264,12 @@ type CSPConfig struct {
Policy string `mapstructure:"policy"`
}
type ProxyFallbackConfig struct {
// AllowDirectOnError 当代理初始化失败时是否允许回退直连。
// 默认 false避免因代理配置错误导致 IP 泄露/关联。
AllowDirectOnError bool `mapstructure:"allow_direct_on_error"`
}
type ProxyProbeConfig struct {
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证
}
@@ -217,6 +290,59 @@ type ConcurrencyConfig struct {
PingInterval int `mapstructure:"ping_interval"`
}
// SoraConfig 直连 Sora 配置
type SoraConfig struct {
Client SoraClientConfig `mapstructure:"client"`
Storage SoraStorageConfig `mapstructure:"storage"`
}
// SoraClientConfig 直连 Sora 客户端配置
type SoraClientConfig struct {
BaseURL string `mapstructure:"base_url"`
TimeoutSeconds int `mapstructure:"timeout_seconds"`
MaxRetries int `mapstructure:"max_retries"`
CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
RecentTaskLimit int `mapstructure:"recent_task_limit"`
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
Debug bool `mapstructure:"debug"`
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
Headers map[string]string `mapstructure:"headers"`
UserAgent string `mapstructure:"user_agent"`
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
}
// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
type SoraCurlCFFISidecarConfig struct {
Enabled bool `mapstructure:"enabled"`
BaseURL string `mapstructure:"base_url"`
Impersonate string `mapstructure:"impersonate"`
TimeoutSeconds int `mapstructure:"timeout_seconds"`
SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
}
// SoraStorageConfig 媒体存储配置
type SoraStorageConfig struct {
Type string `mapstructure:"type"`
LocalPath string `mapstructure:"local_path"`
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
Debug bool `mapstructure:"debug"`
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
}
// SoraStorageCleanupConfig 媒体清理配置
type SoraStorageCleanupConfig struct {
Enabled bool `mapstructure:"enabled"`
Schedule string `mapstructure:"schedule"`
RetentionDays int `mapstructure:"retention_days"`
}
// GatewayConfig API网关相关配置
type GatewayConfig struct {
// 等待上游响应头的超时时间0表示无超时
@@ -224,8 +350,20 @@ type GatewayConfig struct {
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
// 请求体最大字节数,用于网关请求体大小限制
MaxBodySize int64 `mapstructure:"max_body_size"`
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"`
// 代理探测响应体读取上限(字节)
ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"`
// Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销)
GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"`
// ConnectionPoolIsolation: 上游连接池隔离策略proxy/account/account_proxy
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
@@ -271,6 +409,24 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"`
// Sora 专用配置
// SoraMaxBodySize: Sora 请求体最大字节数0 表示使用 gateway.max_body_size
SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"`
// SoraStreamTimeoutSeconds: Sora 流式请求总超时0 表示不限制)
SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"`
// SoraRequestTimeoutSeconds: Sora 非流式请求超时0 表示不限制)
SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"`
// SoraStreamMode: stream 强制策略force/error
SoraStreamMode string `mapstructure:"sora_stream_mode"`
// SoraModelFilters: 模型列表过滤配置
SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"`
// SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key
SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"`
// SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名)
SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"`
// SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用)
SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"`
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
MaxAccountSwitches int `mapstructure:"max_account_switches"`
// Gemini 账户切换最大次数Gemini 平台单独配置,因 API 限制更严格)
@@ -284,6 +440,53 @@ type GatewayConfig struct {
// TLSFingerprint: TLS指纹伪装配置
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
// UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker
UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"`
// UserGroupRateCacheTTLSeconds: 用户分组倍率热路径缓存 TTL
UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"`
// ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
}
// GatewayUsageRecordConfig 使用量记录异步队列配置
type GatewayUsageRecordConfig struct {
// WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限)
WorkerCount int `mapstructure:"worker_count"`
// QueueSize: 队列容量(有界)
QueueSize int `mapstructure:"queue_size"`
// TaskTimeoutSeconds: 单个使用量记录任务超时(秒)
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
// OverflowPolicy: 队列满时策略drop/sample/sync
OverflowPolicy string `mapstructure:"overflow_policy"`
// OverflowSamplePercent: sample 策略下同步回写采样百分比1-100
OverflowSamplePercent int `mapstructure:"overflow_sample_percent"`
// AutoScaleEnabled: 是否启用 worker 自动扩缩容
AutoScaleEnabled bool `mapstructure:"auto_scale_enabled"`
// AutoScaleMinWorkers: 自动扩缩容最小 worker 数
AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"`
// AutoScaleMaxWorkers: 自动扩缩容最大 worker 数
AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"`
// AutoScaleUpQueuePercent: 队列占用率达到该阈值时触发扩容
AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"`
// AutoScaleDownQueuePercent: 队列占用率低于该阈值时触发缩容
AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"`
// AutoScaleUpStep: 每次扩容步长
AutoScaleUpStep int `mapstructure:"auto_scale_up_step"`
// AutoScaleDownStep: 每次缩容步长
AutoScaleDownStep int `mapstructure:"auto_scale_down_step"`
// AutoScaleCheckIntervalSeconds: 自动扩缩容检测间隔(秒)
AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"`
// AutoScaleCooldownSeconds: 自动扩缩容冷却时间(秒)
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
}
// SoraModelFiltersConfig Sora 模型过滤配置
type SoraModelFiltersConfig struct {
// HidePromptEnhance 是否隐藏 prompt-enhance 模型
HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"`
}
// TLSFingerprintConfig TLS指纹伪装配置
@@ -479,8 +682,9 @@ type OpsMetricsCollectorCacheConfig struct {
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpireHour int `mapstructure:"expire_hour"`
// AccessTokenExpireMinutes: Access Token有效期分钟默认15分钟
// 短有效期减少被盗用风险配合Refresh Token实现无感续期
// AccessTokenExpireMinutes: Access Token有效期分钟
// - >0: 使用分钟配置(优先级高于 ExpireHour
// - =0: 回退使用 ExpireHour向后兼容旧配置
AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"`
// RefreshTokenExpireDays: Refresh Token有效期默认30天
RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"`
@@ -525,6 +729,20 @@ type APIKeyAuthCacheConfig struct {
Singleflight bool `mapstructure:"singleflight"`
}
// SubscriptionCacheConfig 订阅认证 L1 缓存配置
type SubscriptionCacheConfig struct {
L1Size int `mapstructure:"l1_size"`
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
JitterPercent int `mapstructure:"jitter_percent"`
}
// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置。
// 用于将“请求路径触发的维护动作”有界化,避免高并发下 goroutine 膨胀。
type SubscriptionMaintenanceConfig struct {
WorkerCount int `mapstructure:"worker_count"`
QueueSize int `mapstructure:"queue_size"`
}
// DashboardCacheConfig 仪表盘统计缓存配置
type DashboardCacheConfig struct {
// Enabled: 是否启用仪表盘缓存
@@ -588,7 +806,19 @@ func NormalizeRunMode(value string) string {
}
}
// Load 读取并校验完整配置(要求 jwt.secret 已显式提供)。
func Load() (*Config, error) {
return load(false)
}
// LoadForBootstrap 读取启动阶段配置。
//
// 启动阶段允许 jwt.secret 先留空,后续由数据库初始化流程补齐并再次完整校验。
func LoadForBootstrap() (*Config, error) {
return load(true)
}
func load(allowMissingJWTSecret bool) (*Config, error) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
@@ -630,6 +860,7 @@ func Load() (*Config, error) {
if cfg.Server.Mode == "" {
cfg.Server.Mode = "debug"
}
cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL)
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
@@ -648,15 +879,12 @@ func Load() (*Config, error) {
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
if cfg.JWT.Secret == "" {
secret, err := generateJWTSecret(64)
if err != nil {
return nil, fmt.Errorf("generate jwt secret error: %w", err)
}
cfg.JWT.Secret = secret
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level))
cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format))
cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName)
cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment)
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
@@ -667,29 +895,39 @@ func Load() (*Config, error) {
}
cfg.Totp.EncryptionKey = key
cfg.Totp.EncryptionKeyConfigured = false
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
slog.Warn("TOTP encryption key auto-generated. Consider setting a fixed key for production.")
} else {
cfg.Totp.EncryptionKeyConfigured = true
}
originalJWTSecret := cfg.JWT.Secret
if allowMissingJWTSecret && originalJWTSecret == "" {
// 启动阶段允许先无 JWT 密钥,后续在数据库初始化后补齐。
cfg.JWT.Secret = strings.Repeat("0", 32)
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
if allowMissingJWTSecret && originalJWTSecret == "" {
cfg.JWT.Secret = ""
}
if !cfg.Security.URLAllowlist.Enabled {
log.Println("Warning: security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
slog.Warn("security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
}
if !cfg.Security.ResponseHeaders.Enabled {
log.Println("Warning: security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
slog.Warn("security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
}
if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
slog.Warn("JWT secret appears weak; use a 32+ character random secret in production.")
}
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v",
cfg.Security.ResponseHeaders.AdditionalAllowed,
cfg.Security.ResponseHeaders.ForceRemove,
slog.Info("response header policy configured",
"additional_allowed", cfg.Security.ResponseHeaders.AdditionalAllowed,
"force_remove", cfg.Security.ResponseHeaders.ForceRemove,
)
}
@@ -702,7 +940,8 @@ func setDefaults() {
// Server
viper.SetDefault("server.host", "0.0.0.0")
viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.mode", "release")
viper.SetDefault("server.frontend_url", "")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{})
@@ -715,6 +954,25 @@ func setDefaults() {
viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB
viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB
// Log
viper.SetDefault("log.level", "info")
viper.SetDefault("log.format", "console")
viper.SetDefault("log.service_name", "sub2api")
viper.SetDefault("log.env", "production")
viper.SetDefault("log.caller", true)
viper.SetDefault("log.stacktrace_level", "error")
viper.SetDefault("log.output.to_stdout", true)
viper.SetDefault("log.output.to_file", true)
viper.SetDefault("log.output.file_path", "")
viper.SetDefault("log.rotation.max_size_mb", 100)
viper.SetDefault("log.rotation.max_backups", 10)
viper.SetDefault("log.rotation.max_age_days", 7)
viper.SetDefault("log.rotation.compress", true)
viper.SetDefault("log.rotation.local_time", true)
viper.SetDefault("log.sampling.enabled", false)
viper.SetDefault("log.sampling.initial", 100)
viper.SetDefault("log.sampling.thereafter", 100)
// CORS
viper.SetDefault("cors.allowed_origins", []string{})
viper.SetDefault("cors.allow_credentials", true)
@@ -737,7 +995,7 @@ func setDefaults() {
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
viper.SetDefault("security.response_headers.enabled", false)
viper.SetDefault("security.response_headers.enabled", true)
viper.SetDefault("security.response_headers.additional_allowed", []string{})
viper.SetDefault("security.response_headers.force_remove", []string{})
viper.SetDefault("security.csp.enabled", true)
@@ -775,9 +1033,9 @@ func setDefaults() {
viper.SetDefault("database.user", "postgres")
viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "disable")
viper.SetDefault("database.max_open_conns", 50)
viper.SetDefault("database.max_idle_conns", 10)
viper.SetDefault("database.sslmode", "prefer")
viper.SetDefault("database.max_open_conns", 256)
viper.SetDefault("database.max_idle_conns", 128)
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
@@ -789,8 +1047,8 @@ func setDefaults() {
viper.SetDefault("redis.dial_timeout_seconds", 5)
viper.SetDefault("redis.read_timeout_seconds", 3)
viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
viper.SetDefault("redis.pool_size", 1024)
viper.SetDefault("redis.min_idle_conns", 128)
viper.SetDefault("redis.enable_tls", false)
// Ops (vNext)
@@ -810,9 +1068,9 @@ func setDefaults() {
// JWT
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期
viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期
viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新
viper.SetDefault("jwt.access_token_expire_minutes", 0) // 0 表示回退到 expire_hour
viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期
viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新
// TOTP
viper.SetDefault("totp.encryption_key", "")
@@ -830,9 +1088,9 @@ func setDefaults() {
// RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.sha256")
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据的配置
viper.SetDefault("pricing.remote_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256")
viper.SetDefault("pricing.data_dir", "./data")
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
viper.SetDefault("pricing.update_interval_hours", 24)
@@ -849,6 +1107,11 @@ func setDefaults() {
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
viper.SetDefault("api_key_auth_cache.singleflight", true)
// Subscription auth L1 cache
viper.SetDefault("subscription_cache.l1_size", 16384)
viper.SetDefault("subscription_cache.l1_ttl_seconds", 10)
viper.SetDefault("subscription_cache.jitter_percent", 10)
// Dashboard cache
viper.SetDefault("dashboard_cache.enabled", true)
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")
@@ -874,6 +1137,16 @@ func setDefaults() {
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
// Idempotency
viper.SetDefault("idempotency.observe_only", true)
viper.SetDefault("idempotency.default_ttl_seconds", 86400)
viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600)
viper.SetDefault("idempotency.processing_timeout_seconds", 30)
viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5)
viper.SetDefault("idempotency.max_stored_response_len", 64*1024)
viper.SetDefault("idempotency.cleanup_interval_seconds", 60)
viper.SetDefault("idempotency.cleanup_batch_size", 500)
// Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", true)
@@ -882,13 +1155,26 @@ func setDefaults() {
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_account_switches", 10)
viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
viper.SetDefault("gateway.sora_stream_mode", "force")
viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true)
viper.SetDefault("gateway.sora_media_require_api_key", true)
viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认
viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃HTTP/2 场景默认
viper.SetDefault("gateway.max_conns_per_host", 1024) // 每主机最大连接数(含活跃;流式/HTTP1.1 场景可调大,如 2400+
viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
@@ -912,16 +1198,65 @@ func setDefaults() {
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
viper.SetDefault("gateway.usage_record.worker_count", 128)
viper.SetDefault("gateway.usage_record.queue_size", 16384)
viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5)
viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample)
viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10)
viper.SetDefault("gateway.usage_record.auto_scale_enabled", true)
viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128)
viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512)
viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70)
viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15)
viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32)
viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16)
viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3)
viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10)
viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30)
viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15)
// TLS指纹伪装配置默认关闭需要账号级别单独启用
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
viper.SetDefault("concurrency.ping_interval", 10)
// Sora 直连配置
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
viper.SetDefault("sora.client.timeout_seconds", 120)
viper.SetDefault("sora.client.max_retries", 3)
viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
viper.SetDefault("sora.client.poll_interval_seconds", 2)
viper.SetDefault("sora.client.max_poll_attempts", 600)
viper.SetDefault("sora.client.recent_task_limit", 50)
viper.SetDefault("sora.client.recent_task_limit_max", 200)
viper.SetDefault("sora.client.debug", false)
viper.SetDefault("sora.client.use_openai_token_provider", false)
viper.SetDefault("sora.client.headers", map[string]string{})
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
viper.SetDefault("sora.storage.type", "local")
viper.SetDefault("sora.storage.local_path", "")
viper.SetDefault("sora.storage.fallback_to_upstream", true)
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
viper.SetDefault("sora.storage.download_timeout_seconds", 120)
viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20))
viper.SetDefault("sora.storage.debug", false)
viper.SetDefault("sora.storage.cleanup.enabled", true)
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *")
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新适配Google 1小时token
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
// Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
@@ -930,9 +1265,106 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
// Security - proxy fallback
viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
// Subscription Maintenance (bounded queue + worker pool)
viper.SetDefault("subscription_maintenance.worker_count", 2)
viper.SetDefault("subscription_maintenance.queue_size", 1024)
}
func (c *Config) Validate() error {
jwtSecret := strings.TrimSpace(c.JWT.Secret)
if jwtSecret == "" {
return fmt.Errorf("jwt.secret is required")
}
// NOTE: 按 UTF-8 编码后的字节长度计算。
// 选择 bytes 而不是 rune 计数,确保二进制/随机串的长度语义更接近“熵”而非“字符数”。
if len([]byte(jwtSecret)) < 32 {
return fmt.Errorf("jwt.secret must be at least 32 bytes")
}
switch c.Log.Level {
case "debug", "info", "warn", "error":
case "":
return fmt.Errorf("log.level is required")
default:
return fmt.Errorf("log.level must be one of: debug/info/warn/error")
}
switch c.Log.Format {
case "json", "console":
case "":
return fmt.Errorf("log.format is required")
default:
return fmt.Errorf("log.format must be one of: json/console")
}
switch c.Log.StacktraceLevel {
case "none", "error", "fatal":
case "":
return fmt.Errorf("log.stacktrace_level is required")
default:
return fmt.Errorf("log.stacktrace_level must be one of: none/error/fatal")
}
if !c.Log.Output.ToStdout && !c.Log.Output.ToFile {
return fmt.Errorf("log.output.to_stdout and log.output.to_file cannot both be false")
}
if c.Log.Rotation.MaxSizeMB <= 0 {
return fmt.Errorf("log.rotation.max_size_mb must be positive")
}
if c.Log.Rotation.MaxBackups < 0 {
return fmt.Errorf("log.rotation.max_backups must be non-negative")
}
if c.Log.Rotation.MaxAgeDays < 0 {
return fmt.Errorf("log.rotation.max_age_days must be non-negative")
}
if c.Log.Sampling.Enabled {
if c.Log.Sampling.Initial <= 0 {
return fmt.Errorf("log.sampling.initial must be positive when sampling is enabled")
}
if c.Log.Sampling.Thereafter <= 0 {
return fmt.Errorf("log.sampling.thereafter must be positive when sampling is enabled")
}
} else {
if c.Log.Sampling.Initial < 0 {
return fmt.Errorf("log.sampling.initial must be non-negative")
}
if c.Log.Sampling.Thereafter < 0 {
return fmt.Errorf("log.sampling.thereafter must be non-negative")
}
}
if c.SubscriptionMaintenance.WorkerCount < 0 {
return fmt.Errorf("subscription_maintenance.worker_count must be non-negative")
}
if c.SubscriptionMaintenance.QueueSize < 0 {
return fmt.Errorf("subscription_maintenance.queue_size must be non-negative")
}
// Gemini OAuth 配置校验client_id 与 client_secret 必须同时设置或同时留空。
// 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。
geminiClientID := strings.TrimSpace(c.Gemini.OAuth.ClientID)
geminiClientSecret := strings.TrimSpace(c.Gemini.OAuth.ClientSecret)
if (geminiClientID == "") != (geminiClientSecret == "") {
return fmt.Errorf("gemini.oauth.client_id and gemini.oauth.client_secret must be both set or both empty")
}
if strings.TrimSpace(c.Server.FrontendURL) != "" {
if err := ValidateAbsoluteHTTPURL(c.Server.FrontendURL); err != nil {
return fmt.Errorf("server.frontend_url invalid: %w", err)
}
u, err := url.Parse(strings.TrimSpace(c.Server.FrontendURL))
if err != nil {
return fmt.Errorf("server.frontend_url invalid: %w", err)
}
if u.RawQuery != "" || u.ForceQuery {
return fmt.Errorf("server.frontend_url invalid: must not include query")
}
if u.User != nil {
return fmt.Errorf("server.frontend_url invalid: must not include userinfo")
}
warnIfInsecureURL("server.frontend_url", c.Server.FrontendURL)
}
if c.JWT.ExpireHour <= 0 {
return fmt.Errorf("jwt.expire_hour must be positive")
}
@@ -940,20 +1372,20 @@ func (c *Config) Validate() error {
return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
}
if c.JWT.ExpireHour > 24 {
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
slog.Warn("jwt.expire_hour is high; consider shorter expiration for security", "expire_hour", c.JWT.ExpireHour)
}
// JWT Refresh Token配置验证
if c.JWT.AccessTokenExpireMinutes <= 0 {
return fmt.Errorf("jwt.access_token_expire_minutes must be positive")
if c.JWT.AccessTokenExpireMinutes < 0 {
return fmt.Errorf("jwt.access_token_expire_minutes must be non-negative")
}
if c.JWT.AccessTokenExpireMinutes > 720 {
log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes)
slog.Warn("jwt.access_token_expire_minutes is high; consider shorter expiration for security", "access_token_expire_minutes", c.JWT.AccessTokenExpireMinutes)
}
if c.JWT.RefreshTokenExpireDays <= 0 {
return fmt.Errorf("jwt.refresh_token_expire_days must be positive")
}
if c.JWT.RefreshTokenExpireDays > 90 {
log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays)
slog.Warn("jwt.refresh_token_expire_days is high; consider shorter expiration for security", "refresh_token_expire_days", c.JWT.RefreshTokenExpireDays)
}
if c.JWT.RefreshWindowMinutes < 0 {
return fmt.Errorf("jwt.refresh_window_minutes must be non-negative")
@@ -1159,9 +1591,116 @@ func (c *Config) Validate() error {
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
}
}
if c.Idempotency.DefaultTTLSeconds <= 0 {
return fmt.Errorf("idempotency.default_ttl_seconds must be positive")
}
if c.Idempotency.SystemOperationTTLSeconds <= 0 {
return fmt.Errorf("idempotency.system_operation_ttl_seconds must be positive")
}
if c.Idempotency.ProcessingTimeoutSeconds <= 0 {
return fmt.Errorf("idempotency.processing_timeout_seconds must be positive")
}
if c.Idempotency.FailedRetryBackoffSeconds <= 0 {
return fmt.Errorf("idempotency.failed_retry_backoff_seconds must be positive")
}
if c.Idempotency.MaxStoredResponseLen <= 0 {
return fmt.Errorf("idempotency.max_stored_response_len must be positive")
}
if c.Idempotency.CleanupIntervalSeconds <= 0 {
return fmt.Errorf("idempotency.cleanup_interval_seconds must be positive")
}
if c.Idempotency.CleanupBatchSize <= 0 {
return fmt.Errorf("idempotency.cleanup_batch_size must be positive")
}
if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive")
}
if c.Gateway.UpstreamResponseReadMaxBytes <= 0 {
return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive")
}
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
}
if c.Gateway.SoraMaxBodySize < 0 {
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
}
if c.Gateway.SoraStreamTimeoutSeconds < 0 {
return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative")
}
if c.Gateway.SoraRequestTimeoutSeconds < 0 {
return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative")
}
if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 {
return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative")
}
if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" {
switch mode {
case "force", "error":
default:
return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
}
}
if c.Sora.Client.TimeoutSeconds < 0 {
return fmt.Errorf("sora.client.timeout_seconds must be non-negative")
}
if c.Sora.Client.MaxRetries < 0 {
return fmt.Errorf("sora.client.max_retries must be non-negative")
}
if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
}
if c.Sora.Client.PollIntervalSeconds < 0 {
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
}
if c.Sora.Client.MaxPollAttempts < 0 {
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
}
if c.Sora.Client.RecentTaskLimit < 0 {
return fmt.Errorf("sora.client.recent_task_limit must be non-negative")
}
if c.Sora.Client.RecentTaskLimitMax < 0 {
return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative")
}
if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 &&
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
}
if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
}
if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
}
if !c.Sora.Client.CurlCFFISidecar.Enabled {
return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
}
if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
}
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
}
if c.Sora.Storage.DownloadTimeoutSeconds < 0 {
return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative")
}
if c.Sora.Storage.MaxDownloadBytes < 0 {
return fmt.Errorf("sora.storage.max_download_bytes must be non-negative")
}
if c.Sora.Storage.Cleanup.Enabled {
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
}
if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" {
return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled")
}
} else {
if c.Sora.Storage.Cleanup.RetentionDays < 0 {
return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative")
}
}
if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" {
return fmt.Errorf("sora.storage.type must be 'local'")
}
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
switch c.Gateway.ConnectionPoolIsolation {
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
@@ -1183,7 +1722,7 @@ func (c *Config) Validate() error {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
}
if c.Gateway.IdleConnTimeoutSeconds > 180 {
log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds)
slog.Warn("gateway.idle_conn_timeout_seconds is high; consider 60-120 seconds for better connection reuse", "idle_conn_timeout_seconds", c.Gateway.IdleConnTimeoutSeconds)
}
if c.Gateway.MaxUpstreamClients <= 0 {
return fmt.Errorf("gateway.max_upstream_clients must be positive")
@@ -1214,6 +1753,70 @@ func (c *Config) Validate() error {
if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 {
return fmt.Errorf("gateway.max_line_size must be at least 1MB")
}
if c.Gateway.UsageRecord.WorkerCount <= 0 {
return fmt.Errorf("gateway.usage_record.worker_count must be positive")
}
if c.Gateway.UsageRecord.QueueSize <= 0 {
return fmt.Errorf("gateway.usage_record.queue_size must be positive")
}
if c.Gateway.UsageRecord.TaskTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.usage_record.task_timeout_seconds must be positive")
}
switch strings.ToLower(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy)) {
case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync:
default:
return fmt.Errorf("gateway.usage_record.overflow_policy must be one of: %s/%s/%s",
UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync)
}
if c.Gateway.UsageRecord.OverflowSamplePercent < 0 || c.Gateway.UsageRecord.OverflowSamplePercent > 100 {
return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be between 0-100")
}
if strings.EqualFold(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy), UsageRecordOverflowPolicySample) &&
c.Gateway.UsageRecord.OverflowSamplePercent <= 0 {
return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be positive when overflow_policy=sample")
}
if c.Gateway.UsageRecord.AutoScaleEnabled {
if c.Gateway.UsageRecord.AutoScaleMinWorkers <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_min_workers must be positive")
}
if c.Gateway.UsageRecord.AutoScaleMaxWorkers <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be positive")
}
if c.Gateway.UsageRecord.AutoScaleMaxWorkers < c.Gateway.UsageRecord.AutoScaleMinWorkers {
return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be >= auto_scale_min_workers")
}
if c.Gateway.UsageRecord.WorkerCount < c.Gateway.UsageRecord.AutoScaleMinWorkers ||
c.Gateway.UsageRecord.WorkerCount > c.Gateway.UsageRecord.AutoScaleMaxWorkers {
return fmt.Errorf("gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers")
}
if c.Gateway.UsageRecord.AutoScaleUpQueuePercent <= 0 || c.Gateway.UsageRecord.AutoScaleUpQueuePercent > 100 {
return fmt.Errorf("gateway.usage_record.auto_scale_up_queue_percent must be between 1-100")
}
if c.Gateway.UsageRecord.AutoScaleDownQueuePercent < 0 || c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 100 {
return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be between 0-99")
}
if c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= c.Gateway.UsageRecord.AutoScaleUpQueuePercent {
return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be less than auto_scale_up_queue_percent")
}
if c.Gateway.UsageRecord.AutoScaleUpStep <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_up_step must be positive")
}
if c.Gateway.UsageRecord.AutoScaleDownStep <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_down_step must be positive")
}
if c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_check_interval_seconds must be positive")
}
if c.Gateway.UsageRecord.AutoScaleCooldownSeconds < 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative")
}
}
if c.Gateway.UserGroupRateCacheTTLSeconds <= 0 {
return fmt.Errorf("gateway.user_group_rate_cache_ttl_seconds must be positive")
}
if c.Gateway.ModelsListCacheTTLSeconds < 10 || c.Gateway.ModelsListCacheTTLSeconds > 30 {
return fmt.Errorf("gateway.models_list_cache_ttl_seconds must be between 10-30")
}
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
}
@@ -1420,6 +2023,6 @@ func warnIfInsecureURL(field, raw string) {
return
}
if strings.EqualFold(u.Scheme, "http") {
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
slog.Warn("url uses http scheme; use https in production to avoid token leakage", "field", field)
}
}

View File

@@ -8,6 +8,25 @@ import (
"github.com/spf13/viper"
)
func resetViperWithJWTSecret(t *testing.T) {
t.Helper()
viper.Reset()
t.Setenv("JWT_SECRET", strings.Repeat("x", 32))
}
func TestLoadForBootstrapAllowsMissingJWTSecret(t *testing.T) {
viper.Reset()
t.Setenv("JWT_SECRET", "")
cfg, err := LoadForBootstrap()
if err != nil {
t.Fatalf("LoadForBootstrap() error: %v", err)
}
if cfg.JWT.Secret != "" {
t.Fatalf("LoadForBootstrap() should keep empty jwt.secret during bootstrap")
}
}
func TestNormalizeRunMode(t *testing.T) {
tests := []struct {
input string
@@ -29,7 +48,7 @@ func TestNormalizeRunMode(t *testing.T) {
}
func TestLoadDefaultSchedulingConfig(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -56,8 +75,44 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
}
}
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.Idempotency.ObserveOnly {
t.Fatalf("Idempotency.ObserveOnly = false, want true")
}
if cfg.Idempotency.DefaultTTLSeconds != 86400 {
t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 86400", cfg.Idempotency.DefaultTTLSeconds)
}
if cfg.Idempotency.SystemOperationTTLSeconds != 3600 {
t.Fatalf("Idempotency.SystemOperationTTLSeconds = %d, want 3600", cfg.Idempotency.SystemOperationTTLSeconds)
}
}
func TestLoadIdempotencyConfigFromEnv(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("IDEMPOTENCY_OBSERVE_ONLY", "false")
t.Setenv("IDEMPOTENCY_DEFAULT_TTL_SECONDS", "600")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Idempotency.ObserveOnly {
t.Fatalf("Idempotency.ObserveOnly = true, want false")
}
if cfg.Idempotency.DefaultTTLSeconds != 600 {
t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 600", cfg.Idempotency.DefaultTTLSeconds)
}
}
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
cfg, err := Load()
@@ -71,7 +126,7 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
}
func TestLoadDefaultSecurityToggles(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -87,13 +142,69 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
if !cfg.Security.URLAllowlist.AllowPrivateHosts {
t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true")
}
if cfg.Security.ResponseHeaders.Enabled {
t.Fatalf("ResponseHeaders.Enabled = true, want false")
if !cfg.Security.ResponseHeaders.Enabled {
t.Fatalf("ResponseHeaders.Enabled = false, want true")
}
}
func TestLoadDefaultServerMode(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Server.Mode != "release" {
t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release")
}
}
func TestLoadDefaultJWTAccessTokenExpireMinutes(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.JWT.ExpireHour != 24 {
t.Fatalf("JWT.ExpireHour = %d, want 24", cfg.JWT.ExpireHour)
}
if cfg.JWT.AccessTokenExpireMinutes != 0 {
t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 0", cfg.JWT.AccessTokenExpireMinutes)
}
}
func TestLoadJWTAccessTokenExpireMinutesFromEnv(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "90")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.JWT.AccessTokenExpireMinutes != 90 {
t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 90", cfg.JWT.AccessTokenExpireMinutes)
}
}
func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Database.SSLMode != "prefer" {
t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer")
}
}
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -118,7 +229,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
}
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -143,7 +254,7 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
}
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -168,7 +279,7 @@ func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
}
func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -188,7 +299,7 @@ func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
}
func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -207,7 +318,7 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
}
func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -244,7 +355,7 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
}
func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -263,7 +374,7 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
}
func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -282,7 +393,7 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
}
func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -307,7 +418,7 @@ func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
}
func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -326,7 +437,7 @@ func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
}
func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -424,6 +535,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) {
}
}
func TestValidateServerFrontendURL(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Server.FrontendURL = "https://example.com"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() frontend_url valid error: %v", err)
}
cfg.Server.FrontendURL = "https://example.com/path"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() frontend_url with path valid error: %v", err)
}
cfg.Server.FrontendURL = "https://example.com?utm=1"
if err := cfg.Validate(); err == nil {
t.Fatalf("Validate() should reject server.frontend_url with query")
}
cfg.Server.FrontendURL = "https://user:pass@example.com"
if err := cfg.Validate(); err == nil {
t.Fatalf("Validate() should reject server.frontend_url with userinfo")
}
cfg.Server.FrontendURL = "/relative"
if err := cfg.Validate(); err == nil {
t.Fatalf("Validate() should reject relative server.frontend_url")
}
}
func TestValidateFrontendRedirectURL(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
@@ -445,6 +590,7 @@ func TestValidateFrontendRedirectURL(t *testing.T) {
func TestWarnIfInsecureURL(t *testing.T) {
warnIfInsecureURL("test", "http://example.com")
warnIfInsecureURL("test", "bad://url")
warnIfInsecureURL("test", "://invalid")
}
func TestGenerateJWTSecretDefaultLength(t *testing.T) {
@@ -458,7 +604,7 @@ func TestGenerateJWTSecretDefaultLength(t *testing.T) {
}
func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -476,7 +622,7 @@ func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
}
func TestValidateConcurrencyPingInterval(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -493,14 +639,14 @@ func TestValidateConcurrencyPingInterval(t *testing.T) {
}
func TestProvideConfig(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
if _, err := ProvideConfig(); err != nil {
t.Fatalf("ProvideConfig() error: %v", err)
}
}
func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -544,6 +690,24 @@ func TestGenerateJWTSecretWithLength(t *testing.T) {
}
}
func TestDatabaseDSNWithTimezone_WithPassword(t *testing.T) {
d := &DatabaseConfig{
Host: "localhost",
Port: 5432,
User: "u",
Password: "p",
DBName: "db",
SSLMode: "prefer",
}
got := d.DSNWithTimezone("UTC")
if !strings.Contains(got, "password=p") {
t.Fatalf("DSNWithTimezone should include password: %q", got)
}
if !strings.Contains(got, "TimeZone=UTC") {
t.Fatalf("DSNWithTimezone should include TimeZone=UTC: %q", got)
}
}
func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
@@ -566,10 +730,35 @@ func TestWarnIfInsecureURLHTTPS(t *testing.T) {
warnIfInsecureURL("secure", "https://example.com")
}
func TestValidateJWTSecret_UTF8Bytes(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
// 31 bytes (< 32) even though it's 31 characters.
cfg.JWT.Secret = strings.Repeat("a", 31)
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() should reject 31-byte secret")
}
if !strings.Contains(err.Error(), "at least 32 bytes") {
t.Fatalf("Validate() error = %v", err)
}
// 32 bytes OK.
cfg.JWT.Secret = strings.Repeat("a", 32)
err = cfg.Validate()
if err != nil {
t.Fatalf("Validate() should accept 32-byte secret: %v", err)
}
}
func TestValidateConfigErrors(t *testing.T) {
buildValid := func(t *testing.T) *Config {
t.Helper()
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
@@ -582,6 +771,26 @@ func TestValidateConfigErrors(t *testing.T) {
mutate func(*Config)
wantErr string
}{
{
name: "jwt secret required",
mutate: func(c *Config) { c.JWT.Secret = "" },
wantErr: "jwt.secret is required",
},
{
name: "jwt secret min bytes",
mutate: func(c *Config) { c.JWT.Secret = strings.Repeat("a", 31) },
wantErr: "jwt.secret must be at least 32 bytes",
},
{
name: "subscription maintenance worker_count non-negative",
mutate: func(c *Config) { c.SubscriptionMaintenance.WorkerCount = -1 },
wantErr: "subscription_maintenance.worker_count",
},
{
name: "subscription maintenance queue_size non-negative",
mutate: func(c *Config) { c.SubscriptionMaintenance.QueueSize = -1 },
wantErr: "subscription_maintenance.queue_size",
},
{
name: "jwt expire hour positive",
mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
@@ -592,6 +801,11 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.JWT.ExpireHour = 200 },
wantErr: "jwt.expire_hour must be <= 168",
},
{
name: "jwt access token expire minutes non-negative",
mutate: func(c *Config) { c.JWT.AccessTokenExpireMinutes = -1 },
wantErr: "jwt.access_token_expire_minutes must be non-negative",
},
{
name: "csp policy required",
mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" },
@@ -799,6 +1013,84 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
wantErr: "gateway.max_line_size must be non-negative",
},
{
name: "gateway usage record worker count",
mutate: func(c *Config) { c.Gateway.UsageRecord.WorkerCount = 0 },
wantErr: "gateway.usage_record.worker_count",
},
{
name: "gateway usage record queue size",
mutate: func(c *Config) { c.Gateway.UsageRecord.QueueSize = 0 },
wantErr: "gateway.usage_record.queue_size",
},
{
name: "gateway usage record timeout",
mutate: func(c *Config) { c.Gateway.UsageRecord.TaskTimeoutSeconds = 0 },
wantErr: "gateway.usage_record.task_timeout_seconds",
},
{
name: "gateway usage record overflow policy",
mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowPolicy = "invalid" },
wantErr: "gateway.usage_record.overflow_policy",
},
{
name: "gateway usage record sample percent range",
mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowSamplePercent = 101 },
wantErr: "gateway.usage_record.overflow_sample_percent",
},
{
name: "gateway usage record sample percent required for sample policy",
mutate: func(c *Config) {
c.Gateway.UsageRecord.OverflowPolicy = UsageRecordOverflowPolicySample
c.Gateway.UsageRecord.OverflowSamplePercent = 0
},
wantErr: "gateway.usage_record.overflow_sample_percent must be positive",
},
{
name: "gateway usage record auto scale max gte min",
mutate: func(c *Config) {
c.Gateway.UsageRecord.AutoScaleMinWorkers = 256
c.Gateway.UsageRecord.AutoScaleMaxWorkers = 128
},
wantErr: "gateway.usage_record.auto_scale_max_workers",
},
{
name: "gateway usage record worker in auto scale range",
mutate: func(c *Config) {
c.Gateway.UsageRecord.AutoScaleMinWorkers = 200
c.Gateway.UsageRecord.AutoScaleMaxWorkers = 300
c.Gateway.UsageRecord.WorkerCount = 128
},
wantErr: "gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers",
},
{
name: "gateway usage record auto scale queue thresholds order",
mutate: func(c *Config) {
c.Gateway.UsageRecord.AutoScaleUpQueuePercent = 50
c.Gateway.UsageRecord.AutoScaleDownQueuePercent = 50
},
wantErr: "gateway.usage_record.auto_scale_down_queue_percent must be less",
},
{
name: "gateway usage record auto scale up step",
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleUpStep = 0 },
wantErr: "gateway.usage_record.auto_scale_up_step",
},
{
name: "gateway usage record auto scale interval",
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 },
wantErr: "gateway.usage_record.auto_scale_check_interval_seconds",
},
{
name: "gateway user group rate cache ttl",
mutate: func(c *Config) { c.Gateway.UserGroupRateCacheTTLSeconds = 0 },
wantErr: "gateway.user_group_rate_cache_ttl_seconds",
},
{
name: "gateway models list cache ttl range",
mutate: func(c *Config) { c.Gateway.ModelsListCacheTTLSeconds = 31 },
wantErr: "gateway.models_list_cache_ttl_seconds",
},
{
name: "gateway scheduling sticky waiting",
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
@@ -822,6 +1114,37 @@ func TestValidateConfigErrors(t *testing.T) {
},
wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
},
{
name: "log level invalid",
mutate: func(c *Config) { c.Log.Level = "trace" },
wantErr: "log.level",
},
{
name: "log format invalid",
mutate: func(c *Config) { c.Log.Format = "plain" },
wantErr: "log.format",
},
{
name: "log output disabled",
mutate: func(c *Config) {
c.Log.Output.ToStdout = false
c.Log.Output.ToFile = false
},
wantErr: "log.output.to_stdout and log.output.to_file cannot both be false",
},
{
name: "log rotation size",
mutate: func(c *Config) { c.Log.Rotation.MaxSizeMB = 0 },
wantErr: "log.rotation.max_size_mb",
},
{
name: "log sampling enabled invalid",
mutate: func(c *Config) {
c.Log.Sampling.Enabled = true
c.Log.Sampling.Initial = 0
},
wantErr: "log.sampling.initial",
},
{
name: "ops metrics collector ttl",
mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
@@ -850,3 +1173,234 @@ func TestValidateConfigErrors(t *testing.T) {
})
}
}
func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Gateway.UsageRecord.AutoScaleEnabled = false
cfg.Gateway.UsageRecord.WorkerCount = 64
// 自动扩缩容关闭时,这些字段应被忽略,不应导致校验失败。
cfg.Gateway.UsageRecord.AutoScaleMinWorkers = 0
cfg.Gateway.UsageRecord.AutoScaleMaxWorkers = 0
cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent = 0
cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent = 100
cfg.Gateway.UsageRecord.AutoScaleUpStep = 0
cfg.Gateway.UsageRecord.AutoScaleDownStep = 0
cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0
cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds = -1
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() should ignore auto scale fields when disabled: %v", err)
}
}
func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) {
resetViperWithJWTSecret(t)
cases := []struct {
name string
mutate func(*Config)
wantErr string
}{
{
name: "log level required",
mutate: func(c *Config) {
c.Log.Level = ""
},
wantErr: "log.level is required",
},
{
name: "log format required",
mutate: func(c *Config) {
c.Log.Format = ""
},
wantErr: "log.format is required",
},
{
name: "log stacktrace required",
mutate: func(c *Config) {
c.Log.StacktraceLevel = ""
},
wantErr: "log.stacktrace_level is required",
},
{
name: "log max backups non-negative",
mutate: func(c *Config) {
c.Log.Rotation.MaxBackups = -1
},
wantErr: "log.rotation.max_backups must be non-negative",
},
{
name: "log max age non-negative",
mutate: func(c *Config) {
c.Log.Rotation.MaxAgeDays = -1
},
wantErr: "log.rotation.max_age_days must be non-negative",
},
{
name: "sampling thereafter non-negative when disabled",
mutate: func(c *Config) {
c.Log.Sampling.Enabled = false
c.Log.Sampling.Thereafter = -1
},
wantErr: "log.sampling.thereafter must be non-negative",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
tt.mutate(cfg)
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr)
}
})
}
}
func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.Sora.Client.CurlCFFISidecar.Enabled {
t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
}
if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 {
t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
}
if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" {
t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
}
if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" {
t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
}
if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled {
t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
}
if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 {
t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
}
}
func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Sora.Client.CurlCFFISidecar.Enabled = false
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
}
}
func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
}
}
func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
}
}
func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
}
}
func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.UsageRecord.WorkerCount != 128 {
t.Fatalf("worker_count = %d, want 128", cfg.Gateway.UsageRecord.WorkerCount)
}
if cfg.Gateway.UsageRecord.QueueSize != 16384 {
t.Fatalf("queue_size = %d, want 16384", cfg.Gateway.UsageRecord.QueueSize)
}
if cfg.Gateway.UsageRecord.TaskTimeoutSeconds != 5 {
t.Fatalf("task_timeout_seconds = %d, want 5", cfg.Gateway.UsageRecord.TaskTimeoutSeconds)
}
if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample {
t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySample)
}
if cfg.Gateway.UsageRecord.OverflowSamplePercent != 10 {
t.Fatalf("overflow_sample_percent = %d, want 10", cfg.Gateway.UsageRecord.OverflowSamplePercent)
}
if !cfg.Gateway.UsageRecord.AutoScaleEnabled {
t.Fatalf("auto_scale_enabled = false, want true")
}
if cfg.Gateway.UsageRecord.AutoScaleMinWorkers != 128 {
t.Fatalf("auto_scale_min_workers = %d, want 128", cfg.Gateway.UsageRecord.AutoScaleMinWorkers)
}
if cfg.Gateway.UsageRecord.AutoScaleMaxWorkers != 512 {
t.Fatalf("auto_scale_max_workers = %d, want 512", cfg.Gateway.UsageRecord.AutoScaleMaxWorkers)
}
if cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent != 70 {
t.Fatalf("auto_scale_up_queue_percent = %d, want 70", cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent)
}
if cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent != 15 {
t.Fatalf("auto_scale_down_queue_percent = %d, want 15", cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent)
}
if cfg.Gateway.UsageRecord.AutoScaleUpStep != 32 {
t.Fatalf("auto_scale_up_step = %d, want 32", cfg.Gateway.UsageRecord.AutoScaleUpStep)
}
if cfg.Gateway.UsageRecord.AutoScaleDownStep != 16 {
t.Fatalf("auto_scale_down_step = %d, want 16", cfg.Gateway.UsageRecord.AutoScaleDownStep)
}
if cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds != 3 {
t.Fatalf("auto_scale_check_interval_seconds = %d, want 3", cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds)
}
if cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds != 10 {
t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds)
}
}

View File

@@ -9,5 +9,5 @@ var ProviderSet = wire.NewSet(
// ProvideConfig 提供应用配置
func ProvideConfig() (*Config, error) {
return Load()
return LoadForBootstrap()
}

View File

@@ -22,6 +22,7 @@ const (
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
PlatformSora = "sora"
)
// Account type constants
@@ -73,6 +74,7 @@ var DefaultAntigravityModelMapping = map[string]string{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
"claude-sonnet-4-6": "claude-sonnet-4-6",
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
// Claude 详细版本 ID 映射
@@ -87,14 +89,21 @@ var DefaultAntigravityModelMapping = map[string]string{
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
// Gemini 3 白名单
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image",
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
// Gemini 3 preview 映射
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3-pro-high",
// Gemini 3.1 白名单
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
// Gemini 3.1 preview 映射
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
// Gemini 3.1 image 白名单
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
// Gemini 3.1 image preview 映射
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
// 其他官方模型
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview",

View File

@@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
return
}
dataPayload := req.Data
if err := validateDataHeader(dataPayload); err != nil {
if err := validateDataHeader(req.Data); err != nil {
response.BadRequest(c, err.Error())
return
}
executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
return h.importData(ctx, req)
})
}
func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) {
skipDefaultGroupBind := true
if req.SkipDefaultGroupBind != nil {
skipDefaultGroupBind = *req.SkipDefaultGroupBind
}
dataPayload := req.Data
result := DataImportResult{}
existingProxies, err := h.listAllProxies(c.Request.Context())
existingProxies, err := h.listAllProxies(ctx)
if err != nil {
response.ErrorFrom(c, err)
return
return result, err
}
proxyKeyToID := make(map[string]int64, len(existingProxies))
@@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
proxyKeyToID[key] = existingID
result.ProxyReused++
if normalizedStatus != "" {
if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{
if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
@@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
continue
}
created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: defaultProxyName(item.Name),
Protocol: item.Protocol,
Host: item.Host,
@@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
Username: item.Username,
Password: item.Password,
})
if err != nil {
if createErr != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
Message: createErr.Error(),
})
continue
}
@@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result.ProxyCreated++
if normalizedStatus != "" && normalizedStatus != created.Status {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{
_, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
@@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
SkipDefaultGroupBind: skipDefaultGroupBind,
}
if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil {
if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
@@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result.AccountCreated++
}
response.Success(c, result)
return result, nil
}
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {

View File

@@ -2,7 +2,13 @@
package admin
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
@@ -10,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -133,6 +140,13 @@ type BulkUpdateAccountsRequest struct {
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// CheckMixedChannelRequest represents check mixed channel risk request
type CheckMixedChannelRequest struct {
Platform string `json:"platform" binding:"required"`
GroupIDs []int64 `json:"group_ids"`
AccountID *int64 `json:"account_id"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct {
*dto.Account
@@ -142,6 +156,44 @@ type AccountWithConcurrency struct {
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
}
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
item := AccountWithConcurrency{
Account: dto.AccountFromService(account),
CurrentConcurrency: 0,
}
if account == nil {
return item
}
if h.concurrencyService != nil {
if counts, err := h.concurrencyService.GetAccountConcurrencyBatch(ctx, []int64{account.ID}); err == nil {
item.CurrentConcurrency = counts[account.ID]
}
}
if account.IsAnthropicOAuthOrSetupToken() {
if h.accountUsageService != nil && account.GetWindowCostLimit() > 0 {
startTime := account.GetCurrentWindowStartTime()
if stats, err := h.accountUsageService.GetAccountWindowStats(ctx, account.ID, startTime); err == nil && stats != nil {
cost := stats.StandardCost
item.CurrentWindowCost = &cost
}
}
if h.sessionLimitCache != nil && account.GetMaxSessions() > 0 {
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
idleTimeouts := map[int64]time.Duration{account.ID: idleTimeout}
if sessions, err := h.sessionLimitCache.GetActiveSessionCountBatch(ctx, []int64{account.ID}, idleTimeouts); err == nil {
if count, ok := sessions[account.ID]; ok {
item.ActiveSessions = &count
}
}
}
}
return item
}
// List handles listing all accounts with pagination
// GET /api/v1/admin/accounts
func (h *AccountHandler) List(c *gin.Context) {
@@ -262,9 +314,71 @@ func (h *AccountHandler) List(c *gin.Context) {
result[i] = item
}
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search)
if etag != "" {
c.Header("ETag", etag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), etag) {
c.Status(http.StatusNotModified)
return
}
}
response.Paginated(c, result, total, page, pageSize)
}
func buildAccountsListETag(
items []AccountWithConcurrency,
total int64,
page, pageSize int,
platform, accountType, status, search string,
) string {
payload := struct {
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Platform string `json:"platform"`
AccountType string `json:"type"`
Status string `json:"status"`
Search string `json:"search"`
Items []AccountWithConcurrency `json:"items"`
}{
Total: total,
Page: page,
PageSize: pageSize,
Platform: platform,
AccountType: accountType,
Status: status,
Search: search,
Items: items,
}
raw, err := json.Marshal(payload)
if err != nil {
return ""
}
sum := sha256.Sum256(raw)
return "\"" + hex.EncodeToString(sum[:]) + "\""
}
func ifNoneMatchMatched(ifNoneMatch, etag string) bool {
if etag == "" || ifNoneMatch == "" {
return false
}
for _, token := range strings.Split(ifNoneMatch, ",") {
candidate := strings.TrimSpace(token)
if candidate == "*" {
return true
}
if candidate == etag {
return true
}
if strings.HasPrefix(candidate, "W/") && strings.TrimPrefix(candidate, "W/") == etag {
return true
}
}
return false
}
// GetByID handles getting an account by ID
// GET /api/v1/admin/accounts/:id
func (h *AccountHandler) GetByID(c *gin.Context) {
@@ -280,7 +394,51 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
// POST /api/v1/admin/accounts/check-mixed-channel
func (h *AccountHandler) CheckMixedChannel(c *gin.Context) {
var req CheckMixedChannelRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.GroupIDs) == 0 {
response.Success(c, gin.H{"has_risk": false})
return
}
accountID := int64(0)
if req.AccountID != nil {
accountID = *req.AccountID
}
err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs)
if err != nil {
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
response.Success(c, gin.H{
"has_risk": true,
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
})
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"has_risk": false})
}
// Create handles creating a new account
@@ -299,46 +457,51 @@ func (h *AccountHandler) Create(c *gin.Context) {
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: req.Name,
Notes: req.Notes,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: req.Name,
Notes: req.Notes,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if execErr != nil {
return nil, execErr
}
return h.buildAccountResponseWithRuntime(ctx, account), nil
})
if err != nil {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
// 创建接口仅返回最小必要字段,详细信息由专门检查接口提供
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AccountFromService(account))
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}
// Update handles updating an account
@@ -383,17 +546,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
// 更新接口仅返回最小必要字段,详细信息由专门检查接口提供
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
@@ -402,7 +558,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// Delete handles deleting an account
@@ -660,7 +816,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
response.Success(c, dto.AccountFromService(updatedAccount))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
}
// GetStats handles getting account statistics
@@ -718,7 +874,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
}
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// BatchCreate handles batch creating accounts
@@ -732,61 +888,62 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
return
}
ctx := c.Request.Context()
success := 0
failed := 0
results := make([]gin.H, 0, len(req.Accounts))
executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
success := 0
failed := 0
results := make([]gin.H, 0, len(req.Accounts))
for _, item := range req.Accounts {
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
failed++
for _, item := range req.Accounts {
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": "rate_multiplier must be >= 0",
})
continue
}
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: item.ProxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: item.GroupIDs,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": "rate_multiplier must be >= 0",
"id": account.ID,
"success": true,
})
continue
}
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: item.ProxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: item.GroupIDs,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"name": item.Name,
"id": account.ID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
return gin.H{
"success": success,
"failed": failed,
"results": results,
}, nil
})
}
@@ -824,57 +981,58 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
}
ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
// 阶段一:预验证所有账号存在,收集 credentials
type accountUpdate struct {
ID int64
Credentials map[string]any
}
updates := make([]accountUpdate, 0, len(req.AccountIDs))
for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
return
}
// Update credentials field
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials[req.Field] = req.Value
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
}
// Update account
updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials,
}
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
if err != nil {
// 阶段二:依次更新,返回每个账号的成功/失败明细,便于调用方重试
success := 0
failed := 0
successIDs := make([]int64, 0, len(updates))
failedIDs := make([]int64, 0, len(updates))
results := make([]gin.H, 0, len(updates))
for _, u := range updates {
updateInput := &service.UpdateAccountInput{Credentials: u.Credentials}
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
failed++
failedIDs = append(failedIDs, u.ID)
results = append(results, gin.H{
"account_id": accountID,
"account_id": u.ID,
"success": false,
"error": err.Error(),
})
continue
}
success++
successIDs = append(successIDs, u.ID)
results = append(results, gin.H{
"account_id": accountID,
"account_id": u.ID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
"success": success,
"failed": failed,
"success_ids": successIDs,
"failed_ids": failedIDs,
"results": results,
})
}
@@ -1109,7 +1267,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
return
}
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// GetTempUnschedulable handles getting temporary unschedulable status
@@ -1199,7 +1363,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
return
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// GetAvailableModels handles getting available models for an account
@@ -1296,32 +1460,14 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
// Handle Antigravity accounts: return Claude + Gemini models
if account.Platform == service.PlatformAntigravity {
// Antigravity 支持 Claude 和部分 Gemini 模型
type UnifiedModel struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
}
// 直接复用 antigravity.DefaultModels(),与 /v1/models 端点保持同步
response.Success(c, antigravity.DefaultModels())
return
}
var models []UnifiedModel
// 添加 Claude 模型
for _, m := range claude.DefaultModels {
models = append(models, UnifiedModel{
ID: m.ID,
Type: m.Type,
DisplayName: m.DisplayName,
})
}
// 添加 Gemini 3 系列模型用于测试
geminiTestModels := []UnifiedModel{
{ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash"},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview"},
}
models = append(models, geminiTestModels...)
response.Success(c, models)
// Handle Sora accounts
if account.Platform == service.PlatformSora {
response.Success(c, service.DefaultSoraModels(nil))
return
}

View File

@@ -0,0 +1,147 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
router.POST("/api/v1/admin/accounts", accountHandler.Create)
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
return router
}
func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) {
adminSvc := newStubAdminService()
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"platform": "antigravity",
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, false, data["has_risk"])
require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID)
require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform)
require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs)
}
func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.checkMixedErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"platform": "antigravity",
"group_ids": []int64{27},
"account_id": 99,
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, data["has_risk"])
require.Equal(t, "mixed_channel_warning", data["error"])
details, ok := data["details"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(27), details["group_id"])
require.Equal(t, "claude-max", details["group_name"])
require.Equal(t, "Antigravity", details["current_platform"])
require.Equal(t, "Anthropic", details["other_platform"])
require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID)
}
func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.createAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"name": "ag-oauth-1",
"platform": "antigravity",
"type": "oauth",
"credentials": map[string]any{"refresh_token": "rt"},
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}
func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.updateAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}

View File

@@ -0,0 +1,66 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testing.T) {
gin.SetMode(gin.TestMode)
adminSvc := newStubAdminService()
handler := NewAccountHandler(
adminSvc,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
)
router := gin.New()
router.POST("/api/v1/admin/accounts", handler.Create)
body := map[string]any{
"name": "anthropic-key-1",
"platform": "anthropic",
"type": "apikey",
"credentials": map[string]any{
"api_key": "sk-ant-xxx",
"base_url": "https://api.anthropic.com",
},
"extra": map[string]any{
"anthropic_passthrough": true,
},
"concurrency": 1,
"priority": 1,
}
raw, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Len(t, adminSvc.createdAccounts, 1)
created := adminSvc.createdAccounts[0]
require.Equal(t, "anthropic", created.Platform)
require.Equal(t, "apikey", created.Type)
require.NotNil(t, created.Extra)
require.Equal(t, true, created.Extra["anthropic_passthrough"])
}

View File

@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
router.ServeHTTP(rec, req)

View File

@@ -58,6 +58,96 @@ func TestParseOpsDuration(t *testing.T) {
require.False(t, ok)
}
func TestParseOpsOpenAITokenStatsDuration(t *testing.T) {
tests := []struct {
input string
want time.Duration
ok bool
}{
{input: "30m", want: 30 * time.Minute, ok: true},
{input: "1h", want: time.Hour, ok: true},
{input: "1d", want: 24 * time.Hour, ok: true},
{input: "15d", want: 15 * 24 * time.Hour, ok: true},
{input: "30d", want: 30 * 24 * time.Hour, ok: true},
{input: "7d", want: 0, ok: false},
}
for _, tt := range tests {
got, ok := parseOpsOpenAITokenStatsDuration(tt.input)
require.Equal(t, tt.ok, ok, "input=%s", tt.input)
require.Equal(t, tt.want, got, "input=%s", tt.input)
}
}
func TestParseOpsOpenAITokenStatsFilter_Defaults(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
before := time.Now().UTC()
filter, err := parseOpsOpenAITokenStatsFilter(c)
after := time.Now().UTC()
require.NoError(t, err)
require.NotNil(t, filter)
require.Equal(t, "30d", filter.TimeRange)
require.Equal(t, 1, filter.Page)
require.Equal(t, 20, filter.PageSize)
require.Equal(t, 0, filter.TopN)
require.Nil(t, filter.GroupID)
require.Equal(t, "", filter.Platform)
require.True(t, filter.StartTime.Before(filter.EndTime))
require.WithinDuration(t, before.Add(-30*24*time.Hour), filter.StartTime, 2*time.Second)
require.WithinDuration(t, after, filter.EndTime, 2*time.Second)
}
func TestParseOpsOpenAITokenStatsFilter_WithTopN(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(
http.MethodGet,
"/?time_range=1h&platform=openai&group_id=12&top_n=50",
nil,
)
filter, err := parseOpsOpenAITokenStatsFilter(c)
require.NoError(t, err)
require.Equal(t, "1h", filter.TimeRange)
require.Equal(t, "openai", filter.Platform)
require.NotNil(t, filter.GroupID)
require.Equal(t, int64(12), *filter.GroupID)
require.Equal(t, 50, filter.TopN)
require.Equal(t, 0, filter.Page)
require.Equal(t, 0, filter.PageSize)
}
func TestParseOpsOpenAITokenStatsFilter_InvalidParams(t *testing.T) {
tests := []string{
"/?time_range=7d",
"/?group_id=0",
"/?group_id=abc",
"/?top_n=0",
"/?top_n=101",
"/?top_n=10&page=1",
"/?top_n=10&page_size=20",
"/?page=0",
"/?page_size=0",
"/?page_size=101",
}
gin.SetMode(gin.TestMode)
for _, rawURL := range tests {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, rawURL, nil)
_, err := parseOpsOpenAITokenStatsFilter(c)
require.Error(t, err, "url=%s", rawURL)
}
}
func TestParseOpsTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()

View File

@@ -10,19 +10,27 @@ import (
)
type stubAdminService struct {
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
mu sync.Mutex
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
createAccountErr error
updateAccountErr error
checkMixedErr error
lastMixedCheck struct {
accountID int64
platform string
groupIDs []int64
}
mu sync.Mutex
}
func newStubAdminService() *stubAdminService {
@@ -188,11 +196,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre
s.mu.Lock()
s.createdAccounts = append(s.createdAccounts, input)
s.mu.Unlock()
if s.createAccountErr != nil {
return nil, s.createAccountErr
}
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
if s.updateAccountErr != nil {
return nil, s.updateAccountErr
}
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
@@ -224,6 +238,13 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
}
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
s.lastMixedCheck.accountID = currentAccountID
s.lastMixedCheck.platform = currentAccountPlatform
s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...)
return s.checkMixedErr
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
search = strings.TrimSpace(strings.ToLower(search))
filtered := make([]service.Proxy, 0, len(s.proxies))
@@ -327,6 +348,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
}
func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
return &service.ProxyQualityCheckResult{
ProxyID: id,
Score: 95,
Grade: "A",
Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
PassedCount: 5,
WarnCount: 0,
FailedCount: 0,
ChallengeCount: 0,
CheckedAt: time.Now().Unix(),
Items: []service.ProxyQualityCheckItem{
{Target: "base_connectivity", Status: "pass", Message: "ok"},
{Target: "openai", Status: "pass", HTTPStatus: 401},
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
{Target: "gemini", Status: "pass", HTTPStatus: 200},
{Target: "sora", Status: "pass", HTTPStatus: 401},
},
}, nil
}
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
return s.redeems, int64(len(s.redeems)), nil
}

View File

@@ -0,0 +1,208 @@
//go:build unit
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// failingAdminService 嵌入 stubAdminService可配置 UpdateAccount 在指定 ID 时失败。
type failingAdminService struct {
*stubAdminService
failOnAccountID int64
updateCallCount atomic.Int64
}
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
f.updateCallCount.Add(1)
if id == f.failOnAccountID {
return nil, errors.New("database error")
}
return f.stubAdminService.UpdateAccount(ctx, id, input)
}
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test-uuid",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
}
func TestBatchUpdateCredentials_PartialFailure(t *testing.T) {
// 让第 2 个账号ID=2更新时失败
svc := &failingAdminService{
stubAdminService: newStubAdminService(),
failOnAccountID: 2,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "org_uuid",
Value: "test-org",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
// 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细
require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细")
var resp map[string]any
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
data := resp["data"].(map[string]any)
require.Equal(t, float64(2), data["success"], "应有 2 个成功")
require.Equal(t, float64(1), data["failed"], "应有 1 个失败")
// 所有 3 个账号都会被尝试更新(非 fail-fast
require.Equal(t, int64(3), svc.updateCallCount.Load(),
"应调用 3 次 UpdateAccount逐个尝试失败后继续")
}
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
svc := &getAccountFailingService{
stubAdminService: newStubAdminService(),
failOnAccountID: 1,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
}
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
type getAccountFailingService struct {
*stubAdminService
failOnAccountID int64
}
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
if id == f.failOnAccountID {
return nil, errors.New("not found")
}
return f.stubAdminService.GetAccount(ctx, id)
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// intercept_warmup_requests 传入非 bool 类型string应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": "not-a-bool",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"intercept_warmup_requests 传入非 bool 值应返回 400")
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": true,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"intercept_warmup_requests 传入合法 bool 值应返回 200")
}
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入非 string 类型number应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": 12345,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"account_uuid 传入非 string 值应返回 400")
}
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入 null设置为空应正常通过
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": nil,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"account_uuid 传入 null 应返回 200")
}

View File

@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return

View File

@@ -61,7 +61,11 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
if err != nil {
msg := err.Error()
// Treat missing/invalid OAuth client configuration as a user/config error.
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
if strings.Contains(msg, "OAuth client not configured") ||
strings.Contains(msg, "requires your own OAuth Client") ||
strings.Contains(msg, "requires a custom OAuth Client") ||
strings.Contains(msg, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING") ||
strings.Contains(msg, "built-in Gemini CLI OAuth client_secret is not configured") {
response.BadRequest(c, "Failed to generate auth URL: "+msg)
return
}

View File

@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
@@ -38,6 +38,10 @@ type CreateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
@@ -55,7 +59,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
@@ -67,6 +71,10 @@ type UpdateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
@@ -179,6 +187,10 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
@@ -225,6 +237,10 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,

View File

@@ -0,0 +1,115 @@
package admin
import (
"context"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type idempotencyStoreUnavailableMode int
const (
idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota
idempotencyStoreUnavailableFailOpen
)
func executeAdminIdempotent(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) (*service.IdempotencyExecuteResult, error) {
coordinator := service.DefaultIdempotencyCoordinator()
if coordinator == nil {
data, err := execute(c.Request.Context())
if err != nil {
return nil, err
}
return &service.IdempotencyExecuteResult{Data: data}, nil
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
Scope: scope,
ActorScope: actorScope,
Method: c.Request.Method,
Route: c.FullPath(),
IdempotencyKey: c.GetHeader("Idempotency-Key"),
Payload: payload,
RequireKey: true,
TTL: ttl,
}, execute)
}
func executeAdminIdempotentJSON(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute)
}
func executeAdminIdempotentJSONFailOpenOnStoreUnavailable(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute)
}
func executeAdminIdempotentJSONWithMode(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
mode idempotencyStoreUnavailableMode,
execute func(context.Context) (any, error),
) {
result, err := executeAdminIdempotent(c, scope, payload, ttl, execute)
if err != nil {
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
strategy := "fail_close"
if mode == idempotencyStoreUnavailableFailOpen {
strategy = "fail_open"
}
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy)
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy)
if mode == idempotencyStoreUnavailableFailOpen {
data, fallbackErr := execute(c.Request.Context())
if fallbackErr != nil {
response.ErrorFrom(c, fallbackErr)
return
}
c.Header("X-Idempotency-Degraded", "store-unavailable")
response.Success(c, data)
return
}
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}

View File

@@ -0,0 +1,285 @@
package admin
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type storeUnavailableRepoStub struct{}
func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
return nil, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
return errors.New("store unavailable")
}
func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
return errors.New("store unavailable")
}
func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
return 0, errors.New("store unavailable")
}
func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "test-key-1")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable")
}
func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "test-key-2")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded"))
require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue")
}
type memoryIdempotencyRepoStub struct {
mu sync.Mutex
nextID int64
data map[string]*service.IdempotencyRecord
}
func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub {
return &memoryIdempotencyRepoStub{
nextID: 1,
data: make(map[string]*service.IdempotencyRecord),
}
}
func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string {
return scope + "|" + keyHash
}
func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
if in == nil {
return nil
}
out := *in
if in.LockedUntil != nil {
v := *in.LockedUntil
out.LockedUntil = &v
}
if in.ResponseBody != nil {
v := *in.ResponseBody
out.ResponseBody = &v
}
if in.ResponseStatus != nil {
v := *in.ResponseStatus
out.ResponseStatus = &v
}
if in.ErrorReason != nil {
v := *in.ErrorReason
out.ErrorReason = &v
}
return &out
}
func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
k := r.key(record.Scope, record.IdempotencyKeyHash)
if _, ok := r.data[k]; ok {
return false, nil
}
cp := r.clone(record)
cp.ID = r.nextID
r.nextID++
r.data[k] = cp
record.ID = cp.ID
return true, nil
}
func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.clone(r.data[r.key(scope, keyHash)]), nil
}
func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != fromStatus {
return false, nil
}
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
return false, nil
}
rec.Status = service.IdempotencyStatusProcessing
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
rec.ErrorReason = nil
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
return false, nil
}
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusSucceeded
rec.LockedUntil = nil
rec.ExpiresAt = expiresAt
rec.ResponseStatus = &responseStatus
rec.ResponseBody = &responseBody
rec.ErrorReason = nil
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusFailedRetryable
rec.LockedUntil = &lockedUntil
rec.ExpiresAt = expiresAt
rec.ErrorReason = &errorReason
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
return 0, nil
}
func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := newMemoryIdempotencyRepoStub()
cfg := service.DefaultIdempotencyConfig()
cfg.ProcessingTimeout = 2 * time.Second
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed atomic.Int32
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed.Add(1)
time.Sleep(120 * time.Millisecond)
return gin.H{"ok": true}, nil
})
})
call := func() (int, http.Header) {
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "same-key")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
return rec.Code, rec.Header()
}
var status1, status2 int
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
status1, _ = call()
}()
go func() {
defer wg.Done()
status2, _ = call()
}()
wg.Wait()
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once")
status3, headers3 := call()
require.Equal(t, http.StatusOK, status3)
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
require.Equal(t, int32(1), executed.Load())
}

View File

@@ -2,6 +2,7 @@ package admin
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
adminService service.AdminService
}
func oauthPlatformFromPath(c *gin.Context) string {
if strings.Contains(c.FullPath(), "/admin/sora/") {
return service.PlatformSora
}
return service.PlatformOpenAI
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{
@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
}
@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
RefreshToken string `json:"refresh_token"`
RT string `json:"rt"`
ClientID string `json:"client_id"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token
// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
refreshToken = strings.TrimSpace(req.RT)
}
if refreshToken == "" {
response.BadRequest(c, "refresh_token is required")
return
}
var proxyURL string
if req.ProxyID != nil {
@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
}
}
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
if err != nil {
response.ErrorFrom(c, err)
return
@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI account
// ExchangeSoraSessionToken exchanges Sora session token to access token
// POST /api/v1/admin/sora/st2at
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
var req struct {
SessionToken string `json:"session_token"`
ST string `json:"st"`
ProxyID *int64 `json:"proxy_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken := strings.TrimSpace(req.SessionToken)
if sessionToken == "" {
sessionToken = strings.TrimSpace(req.ST)
}
if sessionToken == "" {
response.BadRequest(c, "session_token is required")
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
// POST /api/v1/admin/openai/accounts/:id/refresh
// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
return
}
// Ensure account is OpenAI platform
if !account.IsOpenAI() {
response.BadRequest(c, "Account is not an OpenAI account")
platform := oauthPlatformFromPath(c)
if account.Platform != platform {
response.BadRequest(c, "Account platform does not match OAuth endpoint")
return
}
@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount))
}
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth
// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"`
@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
// Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
platform := oauthPlatformFromPath(c)
// Use email as default name if not provided
name := req.Name
if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email
}
if name == "" {
name = "OpenAI OAuth Account"
if platform == service.PlatformSora {
name = "Sora OAuth Account"
} else {
name = "OpenAI OAuth Account"
}
}
// Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name,
Platform: "openai",
Platform: platform,
Type: "oauth",
Credentials: credentials,
ProxyID: req.ProxyID,

View File

@@ -1,6 +1,7 @@
package admin
import (
"fmt"
"net/http"
"strconv"
"strings"
@@ -218,6 +219,115 @@ func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) {
response.Success(c, data)
}
// GetDashboardOpenAITokenStats returns OpenAI token efficiency stats grouped by model.
// GET /api/v1/admin/ops/dashboard/openai-token-stats
func (h *OpsHandler) GetDashboardOpenAITokenStats(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
filter, err := parseOpsOpenAITokenStatsFilter(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
data, err := h.opsService.GetOpenAITokenStats(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
func parseOpsOpenAITokenStatsFilter(c *gin.Context) (*service.OpsOpenAITokenStatsFilter, error) {
if c == nil {
return nil, fmt.Errorf("invalid request")
}
timeRange := strings.TrimSpace(c.Query("time_range"))
if timeRange == "" {
timeRange = "30d"
}
dur, ok := parseOpsOpenAITokenStatsDuration(timeRange)
if !ok {
return nil, fmt.Errorf("invalid time_range")
}
end := time.Now().UTC()
start := end.Add(-dur)
filter := &service.OpsOpenAITokenStatsFilter{
TimeRange: timeRange,
StartTime: start,
EndTime: end,
Platform: strings.TrimSpace(c.Query("platform")),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
return nil, fmt.Errorf("invalid group_id")
}
filter.GroupID = &id
}
topNRaw := strings.TrimSpace(c.Query("top_n"))
pageRaw := strings.TrimSpace(c.Query("page"))
pageSizeRaw := strings.TrimSpace(c.Query("page_size"))
if topNRaw != "" && (pageRaw != "" || pageSizeRaw != "") {
return nil, fmt.Errorf("invalid query: top_n cannot be used with page/page_size")
}
if topNRaw != "" {
topN, err := strconv.Atoi(topNRaw)
if err != nil || topN < 1 || topN > 100 {
return nil, fmt.Errorf("invalid top_n")
}
filter.TopN = topN
return filter, nil
}
filter.Page = 1
filter.PageSize = 20
if pageRaw != "" {
page, err := strconv.Atoi(pageRaw)
if err != nil || page < 1 {
return nil, fmt.Errorf("invalid page")
}
filter.Page = page
}
if pageSizeRaw != "" {
pageSize, err := strconv.Atoi(pageSizeRaw)
if err != nil || pageSize < 1 || pageSize > 100 {
return nil, fmt.Errorf("invalid page_size")
}
filter.PageSize = pageSize
}
return filter, nil
}
func parseOpsOpenAITokenStatsDuration(v string) (time.Duration, bool) {
switch strings.TrimSpace(v) {
case "30m":
return 30 * time.Minute, true
case "1h":
return time.Hour, true
case "1d":
return 24 * time.Hour, true
case "15d":
return 15 * 24 * time.Hour, true
case "30d":
return 30 * 24 * time.Hour, true
default:
return 0, false
}
}
func pickThroughputBucketSeconds(window time.Duration) int {
// Keep buckets predictable and avoid huge responses.
switch {

View File

@@ -0,0 +1,173 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type testSettingRepo struct {
values map[string]string
}
func newTestSettingRepo() *testSettingRepo {
return &testSettingRepo{values: map[string]string{}}
}
func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
v, err := s.GetValue(ctx, key)
if err != nil {
return nil, err
}
return &service.Setting{Key: key, Value: v}, nil
}
func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
v, ok := s.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return v, nil
}
func (s *testSettingRepo) Set(ctx context.Context, key, value string) error {
s.values[key] = value
return nil
}
func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, k := range keys {
if v, ok := s.values[k]; ok {
out[k] = v
}
}
return out, nil
}
func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
for k, v := range settings {
s.values[k] = v
}
return nil
}
func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(s.values))
for k, v := range s.values {
out[k] = v
}
return out, nil
}
func (s *testSettingRepo) Delete(ctx context.Context, key string) error {
delete(s.values, key)
return nil
}
func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7})
c.Next()
})
}
r.GET("/runtime/logging", handler.GetRuntimeLogConfig)
r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig)
r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig)
return r
}
func newRuntimeOpsService(t *testing.T) *service.OpsService {
t.Helper()
if err := logger.Init(logger.InitOptions{
Level: "info",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: false,
ToFile: false,
},
}); err != nil {
t.Fatalf("init logger: %v", err)
}
settingRepo := newTestSettingRepo()
cfg := &config.Config{
Ops: config.OpsConfig{Enabled: true},
Log: config.LogConfig{
Level: "info",
Caller: true,
StacktraceLevel: "error",
Sampling: config.LogSamplingConfig{
Enabled: false,
Initial: 100,
Thereafter: 100,
},
},
}
return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
}
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, true)
payload := map[string]any{
"level": "debug",
"enable_sampling": false,
"sampling_initial": 100,
"sampling_thereafter": 100,
"caller": true,
"stacktrace_level": "error",
"retention_days": 30,
}
raw, _ := json.Marshal(payload)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String())
}
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String())
}
}

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) {
response.Success(c, updated)
}
// GetRuntimeLogConfig returns runtime log config (DB-backed).
// GET /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config")
return
}
response.Success(c, cfg)
}
// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately.
// PUT /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsRuntimeLogConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline.
// POST /api/v1/admin/ops/runtime/logging/reset
func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetAdvancedSettings returns Ops advanced settings (DB-backed).
// GET /api/v1/admin/ops/advanced-settings
func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) {

View File

@@ -0,0 +1,174 @@
package admin
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type opsSystemLogCleanupRequest struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Level string `json:"level"`
Component string `json:"component"`
RequestID string `json:"request_id"`
ClientRequestID string `json:"client_request_id"`
UserID *int64 `json:"user_id"`
AccountID *int64 `json:"account_id"`
Platform string `json:"platform"`
Model string `json:"model"`
Query string `json:"q"`
}
// ListSystemLogs returns indexed system logs.
// GET /api/v1/admin/ops/system-logs
func (h *OpsHandler) ListSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 200 {
pageSize = 200
}
start, end, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsSystemLogFilter{
Page: page,
PageSize: pageSize,
StartTime: &start,
EndTime: &end,
Level: strings.TrimSpace(c.Query("level")),
Component: strings.TrimSpace(c.Query("component")),
RequestID: strings.TrimSpace(c.Query("request_id")),
ClientRequestID: strings.TrimSpace(c.Query("client_request_id")),
Platform: strings.TrimSpace(c.Query("platform")),
Model: strings.TrimSpace(c.Query("model")),
Query: strings.TrimSpace(c.Query("q")),
}
if v := strings.TrimSpace(c.Query("user_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
filter.UserID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize)
}
// CleanupSystemLogs deletes indexed system logs by filter.
// POST /api/v1/admin/ops/system-logs/cleanup
func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
var req opsSystemLogCleanupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
parseTS := func(raw string) (*time.Time, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, nil
}
if t, err := time.Parse(time.RFC3339Nano, raw); err == nil {
return &t, nil
}
t, err := time.Parse(time.RFC3339, raw)
if err != nil {
return nil, err
}
return &t, nil
}
start, err := parseTS(req.StartTime)
if err != nil {
response.BadRequest(c, "Invalid start_time")
return
}
end, err := parseTS(req.EndTime)
if err != nil {
response.BadRequest(c, "Invalid end_time")
return
}
filter := &service.OpsSystemLogCleanupFilter{
StartTime: start,
EndTime: end,
Level: strings.TrimSpace(req.Level),
Component: strings.TrimSpace(req.Component),
RequestID: strings.TrimSpace(req.RequestID),
ClientRequestID: strings.TrimSpace(req.ClientRequestID),
UserID: req.UserID,
AccountID: req.AccountID,
Platform: strings.TrimSpace(req.Platform),
Model: strings.TrimSpace(req.Model),
Query: strings.TrimSpace(req.Query),
}
deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": deleted})
}
// GetSystemLogIngestionHealth returns sink health metrics.
// GET /api/v1/admin/ops/system-logs/health
func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.opsService.GetSystemLogSinkHealth())
}

View File

@@ -0,0 +1,233 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type responseEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}
func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99})
c.Next()
})
}
r.GET("/logs", handler.ListSystemLogs)
r.POST("/logs/cleanup", handler.CleanupSystemLogs)
r.GET("/logs/health", handler.GetSystemLogIngestionHealth)
return r
}
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
var resp responseEnvelope
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if resp.Code != 0 {
t.Fatalf("unexpected response code: %+v", resp)
}
}
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_Health(t *testing.T) {
sink := service.NewOpsSystemLogSink(nil)
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h = NewOpsHandler(svc)
r = newOpsSystemLogTestRouter(h, false)
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}

View File

@@ -3,7 +3,6 @@ package admin
import (
"context"
"encoding/json"
"log"
"math"
"net"
"net/http"
@@ -16,6 +15,7 @@ import (
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
@@ -252,7 +252,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now)
if err != nil || stats == nil {
if err != nil {
log.Printf("[OpsWS] refresh: get window stats failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err)
}
return
}
@@ -278,7 +278,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
msg, err := json.Marshal(payload)
if err != nil {
log.Printf("[OpsWS] refresh: marshal payload failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err)
return
}
@@ -338,7 +338,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
// Reserve a global slot before upgrading the connection to keep the limit strict.
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
@@ -350,7 +350,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
@@ -359,7 +359,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("[OpsWS] upgrade failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err)
return
}
@@ -452,7 +452,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
conn.SetReadLimit(qpsWSMaxReadBytes)
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
log.Printf("[OpsWS] set read deadline failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err)
return
}
conn.SetPongHandler(func(string) error {
@@ -471,7 +471,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
_, _, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Printf("[OpsWS] read failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err)
}
return
}
@@ -508,7 +508,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
continue
}
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
log.Printf("[OpsWS] write failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err)
cancel()
closeConn()
wg.Wait()
@@ -517,7 +517,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
case <-pingTicker.C:
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
log.Printf("[OpsWS] ping failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err)
cancel()
closeConn()
wg.Wait()
@@ -666,14 +666,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
if parsed, err := strconv.ParseBool(v); err == nil {
cfg.TrustProxy = parsed
} else {
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
}
}
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
prefixes, invalid := parseTrustedProxyList(raw)
if len(invalid) > 0 {
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
}
cfg.TrustedProxies = prefixes
}
@@ -684,7 +684,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
case OriginPolicyStrict, OriginPolicyPermissive:
cfg.OriginPolicy = normalized
default:
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
}
}
@@ -701,14 +701,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits {
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
cfg.MaxConns = int32(parsed)
} else {
log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
}
}
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" {
if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 {
cfg.MaxConnsPerIP = int32(parsed)
} else {
log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
}
}
return cfg

View File

@@ -1,6 +1,7 @@
package admin
import (
"context"
"strconv"
"strings"
@@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) {
return
}
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
})
if err != nil {
return nil, err
}
return dto.ProxyFromService(proxy), nil
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.ProxyFromService(proxy))
}
// Update handles updating a proxy
@@ -236,6 +237,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
response.Success(c, result)
}
// CheckQuality handles checking proxy quality across common AI targets.
// POST /api/v1/admin/proxies/:id/quality-check
func (h *ProxyHandler) CheckQuality(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// GetStats handles getting proxy statistics
// GET /api/v1/admin/proxies/:id/stats
func (h *ProxyHandler) GetStats(c *gin.Context) {

View File

@@ -2,6 +2,7 @@ package admin
import (
"bytes"
"context"
"encoding/csv"
"fmt"
"strconv"
@@ -88,23 +89,24 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return
}
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{
Count: req.Count,
Type: req.Type,
Value: req.Value,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
Count: req.Count,
Type: req.Type,
Value: req.Value,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if execErr != nil {
return nil, execErr
}
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
response.Success(c, out)
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
return out, nil
})
}
// Delete handles deleting a redeem code

View File

@@ -0,0 +1,97 @@
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
func truncateSearchByRune(search string, maxRunes int) string {
if runes := []rune(search); len(runes) > maxRunes {
return string(runes[:maxRunes])
}
return search
}
func TestTruncateSearchByRune(t *testing.T) {
tests := []struct {
name string
input string
maxRunes int
wantLen int // 期望的 rune 长度
}{
{
name: "纯中文超长",
input: string(make([]rune, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "纯 ASCII 超长",
input: string(make([]byte, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "空字符串",
input: "",
maxRunes: 100,
wantLen: 0,
},
{
name: "恰好 100 个字符",
input: string(make([]rune, 100)),
maxRunes: 100,
wantLen: 100,
},
{
name: "不足 100 字符不截断",
input: "hello世界",
maxRunes: 100,
wantLen: 7,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := truncateSearchByRune(tc.input, tc.maxRunes)
require.Equal(t, tc.wantLen, len([]rune(result)))
})
}
}
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
input := ""
for i := 0; i < 101; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
require.Equal(t, 100, len([]rune(result)))
// 验证截断结果是有效的 UTF-8每个中文字符 3 字节)
require.Equal(t, 300, len(result))
}
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
// 50 个 ASCII + 51 个中文 = 101 个 rune
input := ""
for i := 0; i < 50; i++ {
input += "a"
}
for i := 0; i < 51; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
runes := []rune(result)
require.Equal(t, 100, len(runes))
// 前 50 个应该是 'a',后 50 个应该是 '中'
require.Equal(t, 'a', runes[0])
require.Equal(t, 'a', runes[49])
require.Equal(t, '中', runes[50])
require.Equal(t, '中', runes[99])
}

View File

@@ -1,6 +1,7 @@
package admin
import (
"context"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return
}
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
if err != nil {
response.ErrorFrom(c, err)
return
idempotencyPayload := struct {
SubscriptionID int64 `json:"subscription_id"`
Body AdjustSubscriptionRequest `json:"body"`
}{
SubscriptionID: subscriptionID,
Body: req,
}
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days)
if execErr != nil {
return nil, execErr
}
return dto.UserSubscriptionFromServiceAdmin(subscription), nil
})
}
// Revoke handles revoking a subscription

View File

@@ -1,11 +1,15 @@
package admin
import (
"context"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -14,12 +18,14 @@ import (
// SystemHandler handles system-related operations
type SystemHandler struct {
updateSvc *service.UpdateService
lockSvc *service.SystemOperationLockService
}
// NewSystemHandler creates a new SystemHandler
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
return &SystemHandler{
updateSvc: updateSvc,
lockSvc: lockSvc,
}
}
@@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) {
// PerformUpdate downloads and applies the update
// POST /api/v1/admin/system/update
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
operationID := buildSystemOperationID(c, "update")
payload := gin.H{"operation_id": operationID}
executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
lock, release, err := h.acquireSystemLock(ctx, operationID)
if err != nil {
return nil, err
}
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
releaseReason = "SYSTEM_UPDATE_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
})
}
// Rollback restores the previous version
// POST /api/v1/admin/system/rollback
func (h *SystemHandler) Rollback(c *gin.Context) {
if err := h.updateSvc.Rollback(); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
operationID := buildSystemOperationID(c, "rollback")
payload := gin.H{"operation_id": operationID}
executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
lock, release, err := h.acquireSystemLock(ctx, operationID)
if err != nil {
return nil, err
}
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.Rollback(); err != nil {
releaseReason = "SYSTEM_ROLLBACK_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
})
}
// RestartService restarts the systemd service
// POST /api/v1/admin/system/restart
func (h *SystemHandler) RestartService(c *gin.Context) {
// Schedule service restart in background after sending response
// This ensures the client receives the success response before the service restarts
go func() {
// Wait a moment to ensure the response is sent
time.Sleep(500 * time.Millisecond)
sysutil.RestartServiceAsync()
}()
operationID := buildSystemOperationID(c, "restart")
payload := gin.H{"operation_id": operationID}
executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
lock, release, err := h.acquireSystemLock(ctx, operationID)
if err != nil {
return nil, err
}
succeeded := false
defer func() {
release("", succeeded)
}()
response.Success(c, gin.H{
"message": "Service restart initiated",
// Schedule service restart in background after sending response
// This ensures the client receives the success response before the service restarts
go func() {
// Wait a moment to ensure the response is sent
time.Sleep(500 * time.Millisecond)
sysutil.RestartServiceAsync()
}()
succeeded = true
return gin.H{
"message": "Service restart initiated",
"operation_id": lock.OperationID(),
}, nil
})
}
func (h *SystemHandler) acquireSystemLock(
ctx context.Context,
operationID string,
) (*service.SystemOperationLock, func(string, bool), error) {
if h.lockSvc == nil {
return nil, nil, service.ErrIdempotencyStoreUnavail
}
lock, err := h.lockSvc.Acquire(ctx, operationID)
if err != nil {
return nil, nil, err
}
release := func(reason string, succeeded bool) {
releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = h.lockSvc.Release(releaseCtx, lock, succeeded, reason)
}
return lock, release, nil
}
func buildSystemOperationID(c *gin.Context, operation string) string {
key := strings.TrimSpace(c.GetHeader("Idempotency-Key"))
if key == "" {
return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36)
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key
hash := service.HashIdempotencyKey(seed)
if len(hash) > 24 {
hash = hash[:24]
}
return "sysop-" + hash
}

View File

@@ -1,13 +1,14 @@
package admin
import (
"log"
"context"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -378,11 +379,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
operator = subject.UserID
}
page, pageSize := response.ParsePagination(c)
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
if err != nil {
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
response.ErrorFrom(c, err)
return
}
@@ -390,7 +391,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
for i := range tasks {
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
}
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
response.Paginated(c, out, result.Total, page, pageSize)
}
@@ -472,29 +473,36 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
billingType = *filters.BillingType
}
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
userID,
apiKeyID,
accountID,
groupID,
model,
stream,
billingType,
req.Timezone,
)
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
if err != nil {
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
response.ErrorFrom(c, err)
return
idempotencyPayload := struct {
OperatorID int64 `json:"operator_id"`
Body CreateUsageCleanupTaskRequest `json:"body"`
}{
OperatorID: subject.UserID,
Body: req,
}
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
userID,
apiKeyID,
accountID,
groupID,
model,
stream,
billingType,
req.Timezone,
)
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
response.Success(c, dto.UsageCleanupTaskFromService(task))
task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID)
if err != nil {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
return nil, err
}
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
return dto.UsageCleanupTaskFromService(task), nil
})
}
// CancelCleanupTask handles canceling a usage cleanup task
@@ -515,12 +523,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
response.BadRequest(c, "Invalid task id")
return
}
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
}

View File

@@ -1,6 +1,7 @@
package admin
import (
"context"
"strconv"
"strings"
@@ -78,8 +79,8 @@ func (h *UserHandler) List(c *gin.Context) {
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
if runes := []rune(search); len(runes) > 100 {
search = string(runes[:100])
}
filters := service.UserListFilters{
@@ -257,13 +258,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return
}
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
if err != nil {
response.ErrorFrom(c, err)
return
idempotencyPayload := struct {
UserID int64 `json:"user_id"`
Body UpdateBalanceRequest `json:"body"`
}{
UserID: userID,
Body: req,
}
response.Success(c, dto.UserFromServiceAdmin(user))
executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes)
if execErr != nil {
return nil, execErr
}
return dto.UserFromServiceAdmin(user), nil
})
}
// GetUserAPIKeys handles getting user's API keys

View File

@@ -2,6 +2,7 @@
package handler
import (
"context"
"strconv"
"time"
@@ -130,13 +131,14 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
if req.Quota != nil {
svcReq.Quota = *req.Quota
}
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.APIKeyFromService(key))
executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq)
if err != nil {
return nil, err
}
return dto.APIKeyFromService(key), nil
})
}
// Update handles updating an API key

View File

@@ -2,6 +2,7 @@ package handler
import (
"log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -112,12 +113,11 @@ func (h *AuthHandler) Register(c *gin.Context) {
return
}
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if req.VerifyCode == "" {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
// Turnstile 验证 — 始终执行,防止绕过
// TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
@@ -448,17 +448,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return
}
// Build frontend base URL from request
scheme := "https"
if c.Request.TLS == nil {
// Check X-Forwarded-Proto header (common in reverse proxy setups)
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
scheme = proto
} else {
scheme = "http"
}
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
if frontendBaseURL == "" {
slog.Error("server.frontend_url not configured; cannot build password reset link")
response.InternalError(c, "Password reset is not configured")
return
}
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)

View File

@@ -0,0 +1,40 @@
package dto
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestAPIKeyFromService_MapsLastUsedAt(t *testing.T) {
lastUsed := time.Now().UTC().Truncate(time.Second)
src := &service.APIKey{
ID: 1,
UserID: 2,
Key: "sk-map-last-used",
Name: "Mapper",
Status: service.StatusActive,
LastUsedAt: &lastUsed,
}
out := APIKeyFromService(src)
require.NotNil(t, out)
require.NotNil(t, out.LastUsedAt)
require.WithinDuration(t, lastUsed, *out.LastUsedAt, time.Second)
}
func TestAPIKeyFromService_MapsNilLastUsedAt(t *testing.T) {
src := &service.APIKey{
ID: 1,
UserID: 2,
Key: "sk-map-last-used-nil",
Name: "MapperNil",
Status: service.StatusActive,
}
out := APIKeyFromService(src)
require.NotNil(t, out)
require.Nil(t, out.LastUsedAt)
}

View File

@@ -2,6 +2,7 @@
package dto
import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -77,6 +78,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
Status: k.Status,
IPWhitelist: k.IPWhitelist,
IPBlacklist: k.IPBlacklist,
LastUsedAt: k.LastUsedAt,
Quota: k.Quota,
QuotaUsed: k.QuotaUsed,
ExpiresAt: k.ExpiresAt,
@@ -129,23 +131,26 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
func groupFromServiceBase(g *service.Group) Group {
return Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
// 无效请求兜底分组
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
@@ -211,6 +216,13 @@ func AccountFromServiceShallow(a *service.Account) *Account {
enabled := true
out.EnableSessionIDMasking = &enabled
}
// 缓存 TTL 强制替换
if a.IsCacheTTLOverrideEnabled() {
enabled := true
out.CacheTTLOverrideEnabled = &enabled
target := a.GetCacheTTLOverrideTarget()
out.CacheTTLOverrideTarget = &target
}
}
return out
@@ -293,6 +305,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
CountryCode: p.CountryCode,
Region: p.Region,
City: p.City,
QualityStatus: p.QualityStatus,
QualityScore: p.QualityScore,
QualityGrade: p.QualityGrade,
QualitySummary: p.QualitySummary,
QualityChecked: p.QualityChecked,
}
}
@@ -397,7 +414,9 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount,
ImageSize: l.ImageSize,
MediaType: l.MediaType,
UserAgent: l.UserAgent,
CacheTTLOverridden: l.CacheTTLOverridden,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey),
@@ -524,11 +543,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
for i := range r.Subscriptions {
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
}
statuses := make(map[string]string, len(r.Statuses))
for userID, status := range r.Statuses {
statuses[strconv.FormatInt(userID, 10)] = status
}
return &BulkAssignResult{
SuccessCount: r.SuccessCount,
CreatedCount: r.CreatedCount,
ReusedCount: r.ReusedCount,
FailedCount: r.FailedCount,
Subscriptions: subs,
Errors: r.Errors,
Statuses: statuses,
}
}

View File

@@ -38,6 +38,7 @@ type APIKey struct {
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"`
IPBlacklist []string `json:"ip_blacklist"`
LastUsedAt *time.Time `json:"last_used_at"`
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires)
@@ -67,6 +68,12 @@ type Group struct {
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
// Sora 按次计费配置
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
// Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
@@ -150,6 +157,11 @@ type Account struct {
// 从 extra 字段提取,方便前端显示和编辑
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
// 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效)
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -191,6 +203,11 @@ type ProxyWithAccountCount struct {
CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
QualityStatus string `json:"quality_status,omitempty"`
QualityScore *int `json:"quality_score,omitempty"`
QualityGrade string `json:"quality_grade,omitempty"`
QualitySummary string `json:"quality_summary,omitempty"`
QualityChecked *int64 `json:"quality_checked,omitempty"`
}
type ProxyAccountSummary struct {
@@ -269,10 +286,14 @@ type UsageLog struct {
// 图片生成字段
ImageCount int `json:"image_count"`
ImageSize *string `json:"image_size"`
MediaType *string `json:"media_type"`
// User-Agent
UserAgent *string `json:"user_agent"`
// Cache TTL Override 标记
CacheTTLOverridden bool `json:"cache_ttl_overridden"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
@@ -374,9 +395,12 @@ type AdminUserSubscription struct {
type BulkAssignResult struct {
SuccessCount int `json:"success_count"`
CreatedCount int `json:"created_count"`
ReusedCount int `json:"reused_count"`
FailedCount int `json:"failed_count"`
Subscriptions []AdminUserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
Statuses map[string]string `json:"statuses,omitempty"`
}
// PromoCode 注册优惠码

View File

@@ -0,0 +1,160 @@
package handler
import (
"context"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
// GatewayService 隐式实现此接口。
type TempUnscheduler interface {
TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError)
}
// FailoverAction 表示 failover 错误处理后的下一步动作
type FailoverAction int
const (
// FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue
FailoverContinue FailoverAction = iota
// FailoverExhausted 切换次数耗尽(调用方应返回错误响应)
FailoverExhausted
// FailoverCanceled context 已取消(调用方应直接 return
FailoverCanceled
)
const (
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries = 2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay = 500 * time.Millisecond
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
// Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s
// Handler 层只需短暂间隔后重新进入 Service 层即可。
singleAccountBackoffDelay = 2 * time.Second
)
// FailoverState 跨循环迭代共享的 failover 状态
type FailoverState struct {
SwitchCount int
MaxSwitches int
FailedAccountIDs map[int64]struct{}
SameAccountRetryCount map[int64]int
LastFailoverErr *service.UpstreamFailoverError
ForceCacheBilling bool
hasBoundSession bool
}
// NewFailoverState 创建 failover 状态
func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState {
return &FailoverState{
MaxSwitches: maxSwitches,
FailedAccountIDs: make(map[int64]struct{}),
SameAccountRetryCount: make(map[int64]int),
hasBoundSession: hasBoundSession,
}
}
// HandleFailoverError 处理 UpstreamFailoverError返回下一步动作。
// 包含缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。
func (s *FailoverState) HandleFailoverError(
ctx context.Context,
gatewayService TempUnscheduler,
accountID int64,
platform string,
failoverErr *service.UpstreamFailoverError,
) FailoverAction {
s.LastFailoverErr = failoverErr
// 缓存计费判断
if needForceCacheBilling(s.hasBoundSession, failoverErr) {
s.ForceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
s.SameAccountRetryCount[accountID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries)
if !sleepWithContext(ctx, sameAccountRetryDelay) {
return FailoverCanceled
}
return FailoverContinue
}
// 同账号重试用尽,执行临时封禁
if failoverErr.RetryableOnSameAccount {
gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr)
}
// 加入失败列表
s.FailedAccountIDs[accountID] = struct{}{}
// 检查是否耗尽
if s.SwitchCount >= s.MaxSwitches {
return FailoverExhausted
}
// 递增切换计数
s.SwitchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d",
accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)
// Antigravity 平台换号线性递增延时
if platform == service.PlatformAntigravity {
delay := time.Duration(s.SwitchCount-1) * time.Second
if !sleepWithContext(ctx, delay) {
return FailoverCanceled
}
}
return FailoverContinue
}
// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。
// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景:
// 清除排除列表、等待退避后重新选号。
//
// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。
// 返回 FailoverExhausted 时,调用方应返回错误响应。
// 返回 FailoverCanceled 时,调用方应直接 return。
func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction {
if s.LastFailoverErr != nil &&
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
s.SwitchCount <= s.MaxSwitches {
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
singleAccountBackoffDelay, s.SwitchCount)
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
return FailoverCanceled
}
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
s.SwitchCount, s.MaxSwitches)
s.FailedAccountIDs = make(map[int64]struct{})
return FailoverContinue
}
return FailoverExhausted
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
}
// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。
func sleepWithContext(ctx context.Context, d time.Duration) bool {
if d <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(d):
return true
}
}

View File

@@ -0,0 +1,732 @@
package handler
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Mock
// ---------------------------------------------------------------------------
// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。
type mockTempUnscheduler struct {
calls []tempUnscheduleCall
}
type tempUnscheduleCall struct {
accountID int64
failoverErr *service.UpstreamFailoverError
}
func (m *mockTempUnscheduler) TempUnscheduleRetryableError(_ context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) {
m.calls = append(m.calls, tempUnscheduleCall{accountID: accountID, failoverErr: failoverErr})
}
// ---------------------------------------------------------------------------
// Helper
// ---------------------------------------------------------------------------
func newTestFailoverErr(statusCode int, retryable, forceBilling bool) *service.UpstreamFailoverError {
return &service.UpstreamFailoverError{
StatusCode: statusCode,
RetryableOnSameAccount: retryable,
ForceCacheBilling: forceBilling,
}
}
// ---------------------------------------------------------------------------
// NewFailoverState 测试
// ---------------------------------------------------------------------------
func TestNewFailoverState(t *testing.T) {
t.Run("初始化字段正确", func(t *testing.T) {
fs := NewFailoverState(5, true)
require.Equal(t, 5, fs.MaxSwitches)
require.Equal(t, 0, fs.SwitchCount)
require.NotNil(t, fs.FailedAccountIDs)
require.Empty(t, fs.FailedAccountIDs)
require.NotNil(t, fs.SameAccountRetryCount)
require.Empty(t, fs.SameAccountRetryCount)
require.Nil(t, fs.LastFailoverErr)
require.False(t, fs.ForceCacheBilling)
require.True(t, fs.hasBoundSession)
})
t.Run("无绑定会话", func(t *testing.T) {
fs := NewFailoverState(3, false)
require.Equal(t, 3, fs.MaxSwitches)
require.False(t, fs.hasBoundSession)
})
t.Run("零最大切换次数", func(t *testing.T) {
fs := NewFailoverState(0, false)
require.Equal(t, 0, fs.MaxSwitches)
})
}
// ---------------------------------------------------------------------------
// sleepWithContext 测试
// ---------------------------------------------------------------------------
func TestSleepWithContext(t *testing.T) {
t.Run("零时长立即返回true", func(t *testing.T) {
start := time.Now()
ok := sleepWithContext(context.Background(), 0)
require.True(t, ok)
require.Less(t, time.Since(start), 50*time.Millisecond)
})
t.Run("负时长立即返回true", func(t *testing.T) {
start := time.Now()
ok := sleepWithContext(context.Background(), -1*time.Second)
require.True(t, ok)
require.Less(t, time.Since(start), 50*time.Millisecond)
})
t.Run("正常等待后返回true", func(t *testing.T) {
start := time.Now()
ok := sleepWithContext(context.Background(), 50*time.Millisecond)
elapsed := time.Since(start)
require.True(t, ok)
require.GreaterOrEqual(t, elapsed, 40*time.Millisecond)
require.Less(t, elapsed, 500*time.Millisecond)
})
t.Run("已取消context立即返回false", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
start := time.Now()
ok := sleepWithContext(ctx, 5*time.Second)
require.False(t, ok)
require.Less(t, time.Since(start), 50*time.Millisecond)
})
t.Run("等待期间context取消返回false", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(30 * time.Millisecond)
cancel()
}()
start := time.Now()
ok := sleepWithContext(ctx, 5*time.Second)
elapsed := time.Since(start)
require.False(t, ok)
require.Less(t, elapsed, 500*time.Millisecond)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 基本切换流程
// ---------------------------------------------------------------------------
func TestHandleFailoverError_BasicSwitch(t *testing.T) {
t.Run("非重试错误_非Antigravity_直接切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Contains(t, fs.FailedAccountIDs, int64(100))
require.Equal(t, err, fs.LastFailoverErr)
require.False(t, fs.ForceCacheBilling)
require.Empty(t, mock.calls, "不应调用 TempUnschedule")
})
t.Run("非重试错误_Antigravity_第一次切换无延迟", func(t *testing.T) {
// switchCount 从 0→1 时sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, false)
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟应为 0")
})
t.Run("非重试错误_Antigravity_第二次切换有1秒延迟", func(t *testing.T) {
// switchCount 从 1→2 时sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
fs.SwitchCount = 1 // 模拟已切换一次
err := newTestFailoverErr(500, false, false)
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SwitchCount)
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟应约 1s")
require.Less(t, elapsed, 3*time.Second)
})
t.Run("连续切换直到耗尽", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(2, false)
// 第一次切换0→1
err1 := newTestFailoverErr(500, false, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
// 第二次切换1→2
err2 := newTestFailoverErr(502, false, false)
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SwitchCount)
// 第三次已耗尽SwitchCount(2) >= MaxSwitches(2)
err3 := newTestFailoverErr(503, false, false)
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
require.Equal(t, FailoverExhausted, action)
require.Equal(t, 2, fs.SwitchCount, "耗尽时不应继续递增")
// 验证失败账号列表
require.Len(t, fs.FailedAccountIDs, 3)
require.Contains(t, fs.FailedAccountIDs, int64(100))
require.Contains(t, fs.FailedAccountIDs, int64(200))
require.Contains(t, fs.FailedAccountIDs, int64(300))
// LastFailoverErr 应为最后一次的错误
require.Equal(t, err3, fs.LastFailoverErr)
})
t.Run("MaxSwitches为0时首次即耗尽", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(0, false)
err := newTestFailoverErr(500, false, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverExhausted, action)
require.Equal(t, 0, fs.SwitchCount)
require.Contains(t, fs.FailedAccountIDs, int64(100))
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 缓存计费 (ForceCacheBilling)
// ---------------------------------------------------------------------------
func TestHandleFailoverError_CacheBilling(t *testing.T) {
t.Run("hasBoundSession为true时设置ForceCacheBilling", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, true) // hasBoundSession=true
err := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.True(t, fs.ForceCacheBilling)
})
t.Run("failoverErr.ForceCacheBilling为true时设置", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, true) // ForceCacheBilling=true
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.True(t, fs.ForceCacheBilling)
})
t.Run("两者均为false时不设置", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.False(t, fs.ForceCacheBilling)
})
t.Run("一旦设置不会被后续错误重置", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
// 第一次ForceCacheBilling=true → 设置
err1 := newTestFailoverErr(500, false, true)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
require.True(t, fs.ForceCacheBilling)
// 第二次ForceCacheBilling=false → 仍然保持 true
err2 := newTestFailoverErr(502, false, false)
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
require.True(t, fs.ForceCacheBilling, "ForceCacheBilling 一旦设置不应被重置")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 同账号重试 (RetryableOnSameAccount)
// ---------------------------------------------------------------------------
func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
t.Run("第一次重试返回FailoverContinue", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[100])
require.Equal(t, 0, fs.SwitchCount, "同账号重试不应增加切换计数")
require.NotContains(t, fs.FailedAccountIDs, int64(100), "同账号重试不应加入失败列表")
require.Empty(t, mock.calls, "同账号重试期间不应调用 TempUnschedule")
// 验证等待了 sameAccountRetryDelay (500ms)
require.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
require.Less(t, elapsed, 2*time.Second)
})
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
// 第一次
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[100])
// 第二次
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SameAccountRetryCount[100])
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
})
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
// 第一次、第二次重试
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, 2, fs.SameAccountRetryCount[100])
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Contains(t, fs.FailedAccountIDs, int64(100))
// 验证 TempUnschedule 被调用
require.Len(t, mock.calls, 1)
require.Equal(t, int64(100), mock.calls[0].accountID)
require.Equal(t, err, mock.calls[0].failoverErr)
})
t.Run("不同账号独立跟踪重试次数", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(5, false)
err := newTestFailoverErr(400, true, false)
// 账号 100 第一次重试
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[100])
// 账号 200 第一次重试(独立计数)
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[200])
require.Equal(t, 1, fs.SameAccountRetryCount[100], "账号 100 的计数不应受影响")
})
t.Run("重试耗尽后再次遇到同账号_直接切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(5, false)
err := newTestFailoverErr(400, true, false)
// 耗尽账号 100 的重试
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
// 第三次: 重试耗尽 → 切换
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
// 再次遇到账号 100计数仍为 2条件不满足 → 直接切换
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — TempUnschedule 调用验证
// ---------------------------------------------------------------------------
func TestHandleFailoverError_TempUnschedule(t *testing.T) {
t.Run("非重试错误不调用TempUnschedule", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, false) // RetryableOnSameAccount=false
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Empty(t, mock.calls)
})
t.Run("重试错误耗尽后调用TempUnschedule_传入正确参数", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(502, true, false)
// 耗尽重试
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
require.Len(t, mock.calls, 1)
require.Equal(t, int64(42), mock.calls[0].accountID)
require.Equal(t, 502, mock.calls[0].failoverErr.StatusCode)
require.True(t, mock.calls[0].failoverErr.RetryableOnSameAccount)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — Context 取消
// ---------------------------------------------------------------------------
func TestHandleFailoverError_ContextCanceled(t *testing.T) {
t.Run("同账号重试sleep期间context取消", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
start := time.Now()
action := fs.HandleFailoverError(ctx, mock, 100, "openai", err)
elapsed := time.Since(start)
require.Equal(t, FailoverCanceled, action)
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
// 重试计数仍应递增
require.Equal(t, 1, fs.SameAccountRetryCount[100])
})
t.Run("Antigravity延迟期间context取消", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
fs.SwitchCount = 1 // 下一次 switchCount=2 → delay = 1s
err := newTestFailoverErr(500, false, false)
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
start := time.Now()
action := fs.HandleFailoverError(ctx, mock, 100, service.PlatformAntigravity, err)
elapsed := time.Since(start)
require.Equal(t, FailoverCanceled, action)
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回而非等待 1s")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — FailedAccountIDs 跟踪
// ---------------------------------------------------------------------------
func TestHandleFailoverError_FailedAccountIDs(t *testing.T) {
t.Run("切换时添加到失败列表", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
require.Contains(t, fs.FailedAccountIDs, int64(100))
fs.HandleFailoverError(context.Background(), mock, 200, "openai", newTestFailoverErr(502, false, false))
require.Contains(t, fs.FailedAccountIDs, int64(200))
require.Len(t, fs.FailedAccountIDs, 2)
})
t.Run("耗尽时也添加到失败列表", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(0, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
require.Equal(t, FailoverExhausted, action)
require.Contains(t, fs.FailedAccountIDs, int64(100))
})
t.Run("同账号重试期间不添加到失败列表", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(400, true, false))
require.Equal(t, FailoverContinue, action)
require.NotContains(t, fs.FailedAccountIDs, int64(100))
})
t.Run("同一账号多次切换不重复添加", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(5, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
require.Len(t, fs.FailedAccountIDs, 1, "map 天然去重")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — LastFailoverErr 更新
// ---------------------------------------------------------------------------
func TestHandleFailoverError_LastFailoverErr(t *testing.T) {
t.Run("每次调用都更新LastFailoverErr", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err1 := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
require.Equal(t, err1, fs.LastFailoverErr)
err2 := newTestFailoverErr(502, false, false)
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
require.Equal(t, err2, fs.LastFailoverErr)
})
t.Run("同账号重试时也更新LastFailoverErr", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, err, fs.LastFailoverErr)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 综合集成场景
// ---------------------------------------------------------------------------
func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
t.Run("模拟完整failover流程_多账号混合重试与切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, true) // hasBoundSession=true
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
retryErr := newTestFailoverErr(400, true, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Len(t, mock.calls, 1)
// 3. 账号 200 遇到不可重试错误 → 直接切换
switchErr := newTestFailoverErr(500, false, false)
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", switchErr)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SwitchCount)
// 4. 账号 300 遇到不可重试错误 → 再切换
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", switchErr)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 3, fs.SwitchCount)
// 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3)
action = fs.HandleFailoverError(context.Background(), mock, 400, "openai", switchErr)
require.Equal(t, FailoverExhausted, action)
// 最终状态验证
require.Equal(t, 3, fs.SwitchCount, "耗尽时不再递增")
require.Len(t, fs.FailedAccountIDs, 4, "4个不同账号都在失败列表中")
require.True(t, fs.ForceCacheBilling)
require.Len(t, mock.calls, 1, "只有账号 100 触发了 TempUnschedule")
})
t.Run("模拟Antigravity平台完整流程", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(2, false)
err := newTestFailoverErr(500, false, false)
// 第一次切换delay = 0s
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟为 0")
// 第二次切换delay = 1s
start = time.Now()
action = fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
elapsed = time.Since(start)
require.Equal(t, FailoverContinue, action)
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟约 1s")
// 第三次:耗尽(无延迟,因为在检查延迟之前就返回了)
start = time.Now()
action = fs.HandleFailoverError(context.Background(), mock, 300, service.PlatformAntigravity, err)
elapsed = time.Since(start)
require.Equal(t, FailoverExhausted, action)
require.Less(t, elapsed, 200*time.Millisecond, "耗尽时不应有延迟")
})
t.Run("ForceCacheBilling通过错误标志设置", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false) // hasBoundSession=false
// 第一次ForceCacheBilling=false
err1 := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
require.False(t, fs.ForceCacheBilling)
// 第二次ForceCacheBilling=trueAntigravity 粘性会话切换)
err2 := newTestFailoverErr(500, false, true)
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
require.True(t, fs.ForceCacheBilling, "错误标志应触发 ForceCacheBilling")
// 第三次ForceCacheBilling=false但状态仍保持 true
err3 := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
require.True(t, fs.ForceCacheBilling, "不应重置")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 边界条件
// ---------------------------------------------------------------------------
func TestHandleFailoverError_EdgeCases(t *testing.T) {
t.Run("StatusCode为0的错误也能正常处理", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(0, false, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
})
t.Run("AccountID为0也能正常跟踪", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, true, false)
action := fs.HandleFailoverError(context.Background(), mock, 0, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[0])
})
t.Run("负AccountID也能正常跟踪", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, true, false)
action := fs.HandleFailoverError(context.Background(), mock, -1, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[-1])
})
t.Run("空平台名称不触发Antigravity延迟", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
fs.SwitchCount = 1
err := newTestFailoverErr(500, false, false)
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 100, "", err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Less(t, elapsed, 200*time.Millisecond, "空平台不应触发 Antigravity 延迟")
})
}
// ---------------------------------------------------------------------------
// HandleSelectionExhausted 测试
// ---------------------------------------------------------------------------
func TestHandleSelectionExhausted(t *testing.T) {
t.Run("无LastFailoverErr时返回Exhausted", func(t *testing.T) {
fs := NewFailoverState(3, false)
// LastFailoverErr 为 nil
action := fs.HandleSelectionExhausted(context.Background())
require.Equal(t, FailoverExhausted, action)
})
t.Run("非503错误返回Exhausted", func(t *testing.T) {
fs := NewFailoverState(3, false)
fs.LastFailoverErr = newTestFailoverErr(500, false, false)
action := fs.HandleSelectionExhausted(context.Background())
require.Equal(t, FailoverExhausted, action)
})
t.Run("503且未耗尽_等待后返回Continue并清除失败列表", func(t *testing.T) {
fs := NewFailoverState(3, false)
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
fs.FailedAccountIDs[100] = struct{}{}
fs.SwitchCount = 1
start := time.Now()
action := fs.HandleSelectionExhausted(context.Background())
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Empty(t, fs.FailedAccountIDs, "应清除失败账号列表")
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "应等待约 2s")
require.Less(t, elapsed, 5*time.Second)
})
t.Run("503但SwitchCount已超过MaxSwitches_返回Exhausted", func(t *testing.T) {
fs := NewFailoverState(2, false)
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
fs.SwitchCount = 3 // > MaxSwitches(2)
start := time.Now()
action := fs.HandleSelectionExhausted(context.Background())
elapsed := time.Since(start)
require.Equal(t, FailoverExhausted, action)
require.Less(t, elapsed, 100*time.Millisecond, "不应等待")
})
t.Run("503但context已取消_返回Canceled", func(t *testing.T) {
fs := NewFailoverState(3, false)
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
ctx, cancel := context.WithCancel(context.Background())
cancel()
start := time.Now()
action := fs.HandleSelectionExhausted(ctx)
elapsed := time.Since(start)
require.Equal(t, FailoverCanceled, action)
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
})
t.Run("503且SwitchCount等于MaxSwitches_仍可重试", func(t *testing.T) {
fs := NewFailoverState(2, false)
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
fs.SwitchCount = 2 // == MaxSwitches条件是 <=,仍可重试
action := fs.HandleSelectionExhausted(context.Background())
require.Equal(t, FailoverContinue, action)
})
}

View File

@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
@@ -19,11 +18,13 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// GatewayHandler handles API gateway requests
@@ -35,10 +36,12 @@ type GatewayHandler struct {
billingCacheService *service.BillingCacheService
usageService *service.UsageService
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
cfg *config.Config
}
// NewGatewayHandler creates a new GatewayHandler
@@ -51,6 +54,7 @@ func NewGatewayHandler(
billingCacheService *service.BillingCacheService,
usageService *service.UsageService,
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config,
) *GatewayHandler {
@@ -74,10 +78,12 @@ func NewGatewayHandler(
billingCacheService: billingCacheService,
usageService: usageService,
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
cfg: cfg,
}
}
@@ -96,6 +102,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.messages",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
@@ -122,6 +135,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
@@ -161,9 +175,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
reqLog.Warn("gateway.user_wait_counter_increment_failed", zap.Error(err))
// On error, allow request to proceed
} else if !canWait {
reqLog.Info("gateway.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
@@ -180,7 +195,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
reqLog.Warn("gateway.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
@@ -197,7 +212,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
@@ -227,17 +242,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
if sessionBoundAccountID > 0 {
prefetchedGroupID := int64(0)
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
c.Request = c.Request.WithContext(ctx)
}
}
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
if platform == service.PlatformGemini {
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
@@ -247,34 +266,31 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
return
default: // FailoverExhausted
if fs.LastFailoverErr != nil {
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 检查请求拦截预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
@@ -302,21 +318,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
reqLog.Info("gateway.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
// Ensure the wait counter is decremented if we exit before acquiring the slot.
defer func() {
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -327,17 +346,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// Slot acquired: no longer waiting in queue.
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
@@ -346,8 +363,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
@@ -360,43 +377,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
lastFailoverErr = failoverErr
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if failoverErr.RetryableOnSameAccount {
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
case FailoverExhausted:
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
return
case FailoverCanceled:
return
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Forward request failed: %v", err)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
@@ -404,24 +401,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Account: account,
Subscription: subscription,
UserAgent: ua,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fcb,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.gateway.messages"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("gateway.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, forceCacheBilling)
})
return
}
}
@@ -442,44 +444,36 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
for {
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError
fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession)
retryWithFallback := false
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
for {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
if err != nil {
if len(failedAccountIDs) == 0 {
if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
return
default: // FailoverExhausted
if fs.LastFailoverErr != nil {
h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 检查请求拦截预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
@@ -507,20 +501,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
reqLog.Info("gateway.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -531,16 +529,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
@@ -549,8 +546,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
}
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
@@ -563,18 +560,26 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil {
var promptTooLongErr *service.PromptTooLongError
if errors.As(err, &promptTooLongErr) {
log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed)
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
zap.Any("current_group_id", currentAPIKey.GroupID),
zap.Any("fallback_group_id", fallbackGroupID),
zap.Bool("fallback_used", fallbackUsed),
)
if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 {
fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID)
if err != nil {
log.Printf("Resolve fallback group failed: %v", err)
reqLog.Warn("gateway.resolve_fallback_group_failed", zap.Int64("fallback_group_id", *fallbackGroupID), zap.Error(err))
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
if fallbackGroup.Platform != service.PlatformAnthropic ||
fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription ||
fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType)
reqLog.Warn("gateway.fallback_group_invalid",
zap.Int64("fallback_group_id", fallbackGroup.ID),
zap.String("fallback_platform", fallbackGroup.Platform),
zap.String("fallback_subscription_type", fallbackGroup.SubscriptionType),
)
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
@@ -598,43 +603,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
lastFailoverErr = failoverErr
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if failoverErr.RetryableOnSameAccount {
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
case FailoverExhausted:
h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted)
return
case FailoverCanceled:
return
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
@@ -642,24 +627,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: usedAccount,
Account: account,
Subscription: currentSubscription,
UserAgent: ua,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fcb,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.gateway.messages"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", currentAPIKey.ID),
zap.Any("group_id", currentAPIKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("gateway.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, forceCacheBilling)
})
return
}
if !retryWithFallback {
@@ -682,6 +672,17 @@ func (h *GatewayHandler) Models(c *gin.Context) {
groupID = &apiKey.Group.ID
platform = apiKey.Group.Platform
}
if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok && strings.TrimSpace(forcedPlatform) != "" {
platform = forcedPlatform
}
if platform == service.PlatformSora {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": service.DefaultSoraModels(h.cfg),
})
return
}
// Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
@@ -893,65 +894,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
}
const (
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries = 2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay = 500 * time.Millisecond
)
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
func sleepSameAccountRetryDelay(ctx context.Context) bool {
select {
case <-ctx.Done():
return false
case <-time.After(sameAccountRetryDelay):
return true
}
}
// sleepFailoverDelay 账号切换线性递增延时第1次0s、第2次1s、第3次2s…
// 返回 false 表示 context 已取消。
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
delay := time.Duration(switchCount-1) * time.Second
if delay <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(delay):
return true
}
}
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
// 当分组内只有一个可用账号且上游返回 503MODEL_CAPACITY_EXHAUSTED时使用
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
// (最多 3 次、总等待 30s所以 Handler 层的退避只需短暂等待即可。
// 返回 false 表示 context 已取消。
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
// 固定短延时2s
// Service 层已经在原地等待了足够长的时间retryDelay × 重试次数),
// Handler 层只需短暂间隔后重新进入 Service 层即可。
const delay = 2 * time.Second
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount)
select {
case <-ctx.Done():
return false
case <-time.After(delay):
return true
}
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
@@ -1040,6 +982,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
h.errorResponse(c, status, errType, message)
}
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
if c == nil || c.Writer == nil || c.Writer.Written() {
return false
}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
return true
}
// errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
@@ -1067,6 +1018,12 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.count_tokens",
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
@@ -1094,6 +1051,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
@@ -1127,14 +1085,15 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err))
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return
}
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
log.Printf("Forward count_tokens request failed: %v", err)
reqLog.Error("gateway.count_tokens_forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
// 错误响应已在 ForwardCountTokens 中处理
return
}
@@ -1398,7 +1357,25 @@ func billingErrorDetails(err error) (status int, code, message string) {
}
msg := pkgerrors.Message(err)
if msg == "" {
msg = err.Error()
logger.L().With(
zap.String("component", "handler.gateway.billing"),
zap.Error(err),
).Warn("gateway.billing_error_missing_message")
msg = "Billing error"
}
return http.StatusForbidden, "billing_error", msg
}
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
}
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
}
// 回退路径worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
task(ctx)
}

View File

@@ -0,0 +1,49 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &GatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.True(t, wrote)
require.Equal(t, http.StatusBadGateway, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
assert.Equal(t, "error", parsed["type"])
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errorObj["type"])
assert.Equal(t, "Upstream request failed", errorObj["message"])
}
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.String(http.StatusTeapot, "already written")
h := &GatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.False(t, wrote)
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
}

View File

@@ -1,51 +0,0 @@
package handler
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// sleepAntigravitySingleAccountBackoff 测试
// ---------------------------------------------------------------------------
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
ctx := context.Background()
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
elapsed := time.Since(start)
require.True(t, ok, "should return true when context is not canceled")
// 固定延迟 2s
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s")
require.Less(t, elapsed, 5*time.Second, "should not wait too long")
}
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
elapsed := time.Since(start)
require.False(t, ok, "should return false when context is canceled")
require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel")
}
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
// 验证不同 retryCount 都使用固定 2s 延迟
ctx := context.Background()
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 5)
elapsed := time.Since(start)
require.True(t, ok)
// 即使 retryCount=5延迟仍然是固定的 2s
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond)
require.Less(t, elapsed, 5*time.Second)
}

View File

@@ -0,0 +1,340 @@
//go:build unit
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// 目标严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”,
// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时,
// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。
type fakeSchedulerCache struct {
accounts []*service.Account
}
func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerBucket) ([]*service.Account, bool, error) {
return f.accounts, true, nil
}
func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error {
return nil
}
func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) {
return nil, nil
}
func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil }
func (f *fakeSchedulerCache) DeleteAccount(_ context.Context, _ int64) error { return nil }
func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error {
return nil
}
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
return true, nil
}
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
func (f *fakeSchedulerCache) GetOutboxWatermark(_ context.Context) (int64, error) { return 0, nil }
func (f *fakeSchedulerCache) SetOutboxWatermark(_ context.Context, _ int64) error { return nil }
type fakeGroupRepo struct {
group *service.Group
}
func (f *fakeGroupRepo) Create(context.Context, *service.Group) error { return nil }
func (f *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) {
return f.group, nil
}
func (f *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) {
return f.group, nil
}
func (f *fakeGroupRepo) Update(context.Context, *service.Group) error { return nil }
func (f *fakeGroupRepo) Delete(context.Context, int64) error { return nil }
func (f *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) { return nil, nil }
func (f *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (f *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { return nil, nil }
func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) {
return nil, nil
}
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}
func (f *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
return nil, nil
}
func (f *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error { return nil }
func (f *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error {
return nil
}
type fakeConcurrencyCache struct{}
func (f *fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) {
return true, nil
}
func (f *fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error { return nil }
func (f *fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) {
return 0, nil
}
func (f *fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) {
return true, nil
}
func (f *fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error { return nil }
func (f *fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) {
return 0, nil
}
func (f *fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) {
return true, nil
}
func (f *fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error { return nil }
func (f *fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) { return 0, nil }
func (f *fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) {
return true, nil
}
func (f *fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error { return nil }
func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
return map[int64]*service.AccountLoadInfo{}, nil
}
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
return map[int64]*service.UserLoadInfo{}, nil
}
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
t.Helper()
schedulerCache := &fakeSchedulerCache{accounts: accounts}
schedulerSnapshot := service.NewSchedulerSnapshotService(schedulerCache, nil, nil, nil, nil)
gwSvc := service.NewGatewayService(
nil, // accountRepo (not used: scheduler snapshot hit)
&fakeGroupRepo{group: group},
nil, // usageLogRepo
nil, // userRepo
nil, // userSubRepo
nil, // userGroupRateRepo
nil, // cache (disable sticky)
nil, // cfg
schedulerSnapshot,
nil, // concurrencyService (disable load-aware; tryAcquire always acquired)
nil, // billingService
nil, // rateLimitService
nil, // billingCacheService
nil, // identityService
nil, // httpUpstream
nil, // deferredService
nil, // claudeTokenProvider
nil, // sessionLimitCache
nil, // digestStore
)
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。
cfg := &config.Config{RunMode: config.RunModeSimple}
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg)
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
h := &GatewayHandler{
gatewayService: gwSvc,
billingCacheService: billingCacheSvc,
concurrencyHelper: concurrencyHelper,
// 这些字段对本测试不敏感,保持较小即可
maxAccountSwitches: 1,
maxAccountSwitchesGemini: 1,
}
cleanup := func() {
billingCacheSvc.Stop()
}
return h, cleanup
}
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(2001)
accountID := int64(1001)
group := &service.Group{
ID: groupID,
Hydrated: true,
Platform: service.PlatformAnthropic, // /v1/messagesClaude兼容入口
Status: service.StatusActive,
}
account := &service.Account{
ID: accountID,
Name: "ag-1",
Platform: service.PlatformAntigravity,
Type: service.AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "tok_xxx",
"intercept_warmup_requests": true,
},
Extra: map[string]any{
"mixed_scheduling": true, // 关键:允许被 anthropic 分组混合调度选中
},
Concurrency: 1,
Priority: 1,
Status: service.StatusActive,
Schedulable: true,
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
}
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
defer cleanup()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{
"model": "claude-sonnet-4-5",
"max_tokens": 256,
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
}`)
req := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, group))
c.Request = req
apiKey := &service.APIKey{
ID: 3001,
UserID: 4001,
GroupID: &groupID,
Status: service.StatusActive,
User: &service.User{
ID: 4001,
Concurrency: 10,
Balance: 100,
},
Group: group,
}
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
h.Messages(c)
require.Equal(t, 200, rec.Code)
// 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果)
selected, ok := c.Get(opsAccountIDKey)
require.True(t, ok)
require.Equal(t, accountID, selected)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "msg_mock_warmup", resp["id"])
require.Equal(t, "claude-sonnet-4-5", resp["model"])
content, ok := resp["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1)
first, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "New Conversation", first["text"])
}
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(2002)
accountID := int64(1002)
group := &service.Group{
ID: groupID,
Hydrated: true,
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
}
account := &service.Account{
ID: accountID,
Name: "ag-2",
Platform: service.PlatformAntigravity,
Type: service.AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "tok_xxx",
"intercept_warmup_requests": true,
},
Concurrency: 1,
Priority: 1,
Status: service.StatusActive,
Schedulable: true,
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
}
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
defer cleanup()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{
"model": "claude-sonnet-4-5",
"max_tokens": 256,
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
}`)
req := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
// 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果:
// - 写入 request.ContextService读取
// - 写入 gin.ContextHandler快速读取
ctx := context.WithValue(req.Context(), ctxkey.Group, group)
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformAntigravity)
req = req.WithContext(ctx)
c.Request = req
c.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity)
apiKey := &service.APIKey{
ID: 3002,
UserID: 4002,
GroupID: &groupID,
Status: service.StatusActive,
User: &service.User{
ID: 4002,
Concurrency: 10,
Balance: 100,
},
Group: group,
}
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
h.Messages(c)
require.Equal(t, 200, rec.Code)
selected, ok := c.Get(opsAccountIDKey)
require.True(t, ok)
require.Equal(t, accountID, selected)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "msg_mock_warmup", resp["id"])
require.Equal(t, "claude-sonnet-4-5", resp["model"])
}

View File

@@ -4,8 +4,9 @@ import (
"context"
"encoding/json"
"fmt"
"math/rand"
"math/rand/v2"
"net/http"
"strings"
"sync"
"time"
@@ -20,14 +21,28 @@ var claudeCodeValidator = service.NewClaudeCodeValidator()
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// 返回更新后的 context
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
// 解析请求体为 map
var bodyMap map[string]any
if len(body) > 0 {
_ = json.Unmarshal(body, &bodyMap)
if c == nil || c.Request == nil {
return
}
// Fast path非 Claude CLI UA 直接判定 false避免热路径二次 JSON 反序列化。
if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
c.Request = c.Request.WithContext(ctx)
return
}
// 验证是否为 Claude Code 客户端
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap)
isClaudeCode := false
if !strings.Contains(c.Request.URL.Path, "messages") {
// 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。
isClaudeCode = true
} else {
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
var bodyMap map[string]any
if len(body) > 0 {
_ = json.Unmarshal(body, &bodyMap)
}
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
}
// 更新 request context
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
@@ -104,31 +119,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露
// 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil {
return nil
}
var once sync.Once
quit := make(chan struct{})
var stop func() bool
release := func() {
once.Do(func() {
if stop != nil {
_ = stop()
}
releaseFunc()
close(quit) // 通知监听 goroutine 退出
})
}
go func() {
select {
case <-ctx.Done():
// Context 取消时释放资源
release()
case <-quit:
// 正常释放已完成goroutine 退出
return
}
}()
stop = context.AfterFunc(ctx, release)
return release
}
@@ -153,6 +161,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// TryAcquireUserSlot 尝试立即获取用户并发槽位。
// 返回值: (releaseFunc, acquired, error)
func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) {
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return nil, false, err
}
if !result.Acquired {
return nil, false, nil
}
return result.ReleaseFunc, true, nil
}
// TryAcquireAccountSlot 尝试立即获取账号并发槽位。
// 返回值: (releaseFunc, acquired, error)
func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) {
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, false, err
}
if !result.Acquired {
return nil, false, nil
}
return result.ReleaseFunc, true, nil
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
@@ -160,13 +194,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
if acquired {
return releaseFunc, nil
}
// Need to wait - handle streaming ping if needed
@@ -180,13 +214,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
if acquired {
return releaseFunc, nil
}
// Need to wait - handle streaming ping if needed
@@ -196,27 +230,29 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// Try immediate acquire first (avoid unnecessary wait)
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
acquireSlot := func() (*service.AcquireResult, error) {
if slotType == "user" {
return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
}
return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
if tryImmediate {
result, err := acquireSlot()
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
// Determine if ping is needed (streaming + ping format defined)
@@ -242,7 +278,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
backoff := initialBackoff
timer := time.NewTimer(backoff)
defer timer.Stop()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for {
select {
@@ -268,15 +303,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
case <-timer.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
result, err := acquireSlot()
if err != nil {
return nil, err
}
@@ -284,7 +311,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
if result.Acquired {
return result.ReleaseFunc, nil
}
backoff = nextBackoff(backoff, rng)
backoff = nextBackoff(backoff)
timer.Reset(backoff)
}
}
@@ -292,26 +319,22 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true)
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
// rng: 随机数生成器(可为 nil此时不添加抖动
// 返回值下一次退避时间100ms ~ 2s 之间)
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
func nextBackoff(current time.Duration) time.Duration {
// 指数退避:当前时间 * 1.5
next := time.Duration(float64(current) * backoffMultiplier)
if next > maxBackoff {
next = maxBackoff
}
if rng == nil {
return next
}
// 添加 ±20% 的随机抖动jitter 范围 0.8 ~ 1.2
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
jitter := 0.8 + rng.Float64()*0.4
jitter := 0.8 + rand.Float64()*0.4
jittered := time.Duration(float64(next) * jitter)
if jittered < initialBackoff {
return initialBackoff

View File

@@ -0,0 +1,106 @@
package handler
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Task 6.2 验证: math/rand/v2 迁移后 nextBackoff 行为正确 ---
func TestNextBackoff_ExponentialGrowth(t *testing.T) {
// 验证退避时间指数增长(乘数 1.5
// 由于有随机抖动±20%),需要验证范围
current := initialBackoff // 100ms
for i := 0; i < 10; i++ {
next := nextBackoff(current)
// 退避结果应在 [initialBackoff, maxBackoff] 范围内
assert.GreaterOrEqual(t, int64(next), int64(initialBackoff),
"第 %d 次退避不应低于初始值 %v", i, initialBackoff)
assert.LessOrEqual(t, int64(next), int64(maxBackoff),
"第 %d 次退避不应超过最大值 %v", i, maxBackoff)
// 为下一轮提供当前退避值
current = next
}
}
func TestNextBackoff_BoundedByMaxBackoff(t *testing.T) {
// 即使输入非常大,输出也不超过 maxBackoff
for i := 0; i < 100; i++ {
result := nextBackoff(10 * time.Second)
assert.LessOrEqual(t, int64(result), int64(maxBackoff),
"退避值不应超过 maxBackoff")
}
}
func TestNextBackoff_BoundedByInitialBackoff(t *testing.T) {
// 即使输入非常小,输出也不低于 initialBackoff
for i := 0; i < 100; i++ {
result := nextBackoff(1 * time.Millisecond)
assert.GreaterOrEqual(t, int64(result), int64(initialBackoff),
"退避值不应低于 initialBackoff")
}
}
func TestNextBackoff_HasJitter(t *testing.T) {
// 验证多次调用会产生不同的值(随机抖动生效)
// 使用相同的输入调用 50 次,收集结果
results := make(map[time.Duration]bool)
current := 500 * time.Millisecond
for i := 0; i < 50; i++ {
result := nextBackoff(current)
results[result] = true
}
// 50 次调用应该至少有 2 个不同的值(抖动存在)
require.Greater(t, len(results), 1,
"nextBackoff 应产生随机抖动,但所有 50 次调用结果相同")
}
func TestNextBackoff_InitialValueGrows(t *testing.T) {
// 验证从初始值开始,退避趋势是增长的
current := initialBackoff
var sum time.Duration
runs := 100
for i := 0; i < runs; i++ {
next := nextBackoff(current)
sum += next
current = next
}
avg := sum / time.Duration(runs)
// 平均退避时间应大于初始值(因为指数增长 + 上限)
assert.Greater(t, int64(avg), int64(initialBackoff),
"平均退避时间应大于初始退避值")
}
func TestNextBackoff_ConvergesToMaxBackoff(t *testing.T) {
// 从初始值开始,经过多次退避后应收敛到 maxBackoff 附近
current := initialBackoff
for i := 0; i < 20; i++ {
current = nextBackoff(current)
}
// 经过 20 次迭代后,应该已经到达 maxBackoff 区间
// 由于抖动,允许 ±20% 的范围
lowerBound := time.Duration(float64(maxBackoff) * 0.8)
assert.GreaterOrEqual(t, int64(current), int64(lowerBound),
"经过多次退避后应收敛到 maxBackoff 附近")
}
func BenchmarkNextBackoff(b *testing.B) {
current := initialBackoff
for i := 0; i < b.N; i++ {
current = nextBackoff(current)
if current > maxBackoff {
current = initialBackoff
}
}
}

View File

@@ -0,0 +1,114 @@
package handler
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
type concurrencyCacheMock struct {
acquireUserSlotFn func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
acquireAccountSlotFn func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
releaseUserCalled int32
releaseAccountCalled int32
}
func (m *concurrencyCacheMock) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if m.acquireAccountSlotFn != nil {
return m.acquireAccountSlotFn(ctx, accountID, maxConcurrency, requestID)
}
return false, nil
}
func (m *concurrencyCacheMock) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
atomic.AddInt32(&m.releaseAccountCalled, 1)
return nil
}
func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *concurrencyCacheMock) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
return nil
}
func (m *concurrencyCacheMock) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *concurrencyCacheMock) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
if m.acquireUserSlotFn != nil {
return m.acquireUserSlotFn(ctx, userID, maxConcurrency, requestID)
}
return false, nil
}
func (m *concurrencyCacheMock) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
atomic.AddInt32(&m.releaseUserCalled, 1)
return nil
}
func (m *concurrencyCacheMock) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (m *concurrencyCacheMock) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *concurrencyCacheMock) DecrementWaitCount(ctx context.Context, userID int64) error {
return nil
}
func (m *concurrencyCacheMock) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
return map[int64]*service.AccountLoadInfo{}, nil
}
func (m *concurrencyCacheMock) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
return map[int64]*service.UserLoadInfo{}, nil
}
func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
return nil
}
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
}
helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
release, acquired, err := helper.TryAcquireUserSlot(context.Background(), 101, 2)
require.NoError(t, err)
require.True(t, acquired)
require.NotNil(t, release)
release()
require.Equal(t, int32(1), atomic.LoadInt32(&cache.releaseUserCalled))
}
func TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired(t *testing.T) {
cache := &concurrencyCacheMock{
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return false, nil
},
}
helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
release, acquired, err := helper.TryAcquireAccountSlot(context.Background(), 201, 1)
require.NoError(t, err)
require.False(t, acquired)
require.Nil(t, release)
require.Equal(t, int32(0), atomic.LoadInt32(&cache.releaseAccountCalled))
}

View File

@@ -0,0 +1,269 @@
package handler
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type helperConcurrencyCacheStub struct {
mu sync.Mutex
accountSeq []bool
userSeq []bool
accountAcquireCalls int
userAcquireCalls int
accountReleaseCalls int
userReleaseCalls int
}
func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.accountAcquireCalls++
if len(s.accountSeq) == 0 {
return false, nil
}
v := s.accountSeq[0]
s.accountSeq = s.accountSeq[1:]
return v, nil
}
func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
s.mu.Lock()
defer s.mu.Unlock()
s.accountReleaseCalls++
return nil
}
func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
return nil
}
func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.userAcquireCalls++
if len(s.userSeq) == 0 {
return false, nil
}
v := s.userSeq[0]
s.userSeq = s.userSeq[1:]
return v, nil
}
func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
s.mu.Lock()
defer s.mu.Unlock()
s.userReleaseCalls++
return nil
}
func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return true, nil
}
func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error {
return nil
}
func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
out := make(map[int64]*service.AccountLoadInfo, len(accounts))
for _, acc := range accounts {
out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID}
}
return out, nil
}
func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
out := make(map[int64]*service.UserLoadInfo, len(users))
for _, user := range users {
out[user.ID] = &service.UserLoadInfo{UserID: user.ID}
}
return out, nil
}
func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
return nil
}
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(method, path, nil)
return c, rec
}
func validClaudeCodeBodyJSON() []byte {
return []byte(`{
"model":"claude-3-5-sonnet-20241022",
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
}`)
}
func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
t.Run("non_cli_user_agent_sets_false", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "curl/8.6.0")
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
})
t.Run("cli_non_messages_path_sets_true", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
SetClaudeCodeClientContext(c, nil)
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
c.Request.Header.Set("X-App", "claude-code")
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
c.Request.Header.Set("anthropic-version", "2023-06-01")
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
// 缺少严格校验所需 header + body 字段
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
})
}
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false, true},
userSeq: []bool{false, true},
}
concurrency := service.NewConcurrencyService(cache)
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
t.Run("account_slot_acquired_after_retry", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true)
require.NoError(t, err)
require.NotNil(t, release)
require.False(t, streamStarted)
release()
require.GreaterOrEqual(t, cache.accountAcquireCalls, 2)
require.GreaterOrEqual(t, cache.accountReleaseCalls, 1)
})
t.Run("user_slot_acquired_after_retry", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true)
require.NoError(t, err)
require.NotNil(t, release)
release()
require.GreaterOrEqual(t, cache.userAcquireCalls, 2)
require.GreaterOrEqual(t, cache.userReleaseCalls, 1)
})
}
func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false, false, false},
}
concurrency := service.NewConcurrencyService(cache)
t.Run("timeout_returns_concurrency_error", func(t *testing.T) {
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true)
require.Nil(t, release)
var cErr *ConcurrencyError
require.ErrorAs(t, err, &cErr)
require.True(t, cErr.IsTimeout)
})
t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) {
helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond)
c, rec := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true)
require.Nil(t, release)
var cErr *ConcurrencyError
require.ErrorAs(t, err, &cErr)
require.True(t, cErr.IsTimeout)
require.True(t, streamStarted)
require.Contains(t, rec.Body.String(), ":\n\n")
})
}
func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
errCache := &helperConcurrencyCacheStubWithError{
err: errors.New("redis unavailable"),
}
concurrency := service.NewConcurrencyService(errCache)
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true)
require.Nil(t, release)
require.Error(t, err)
require.Contains(t, err.Error(), "redis unavailable")
}
func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false},
}
concurrency := service.NewConcurrencyService(cache)
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted)
require.Nil(t, release)
var cErr *ConcurrencyError
require.ErrorAs(t, err, &cErr)
require.True(t, cErr.IsTimeout)
require.GreaterOrEqual(t, cache.accountAcquireCalls, 1)
}
type helperConcurrencyCacheStubWithError struct {
helperConcurrencyCacheStub
err error
}
func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return false, s.err
}

View File

@@ -8,11 +8,9 @@ import (
"encoding/json"
"errors"
"io"
"log"
"net/http"
"regexp"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
@@ -20,11 +18,13 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/uuid"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
@@ -143,6 +143,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusInternalServerError, "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gemini_v1beta.models",
zap.Int64("user_id", authSubject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则要求 gemini 分组
if !middleware.HasForcePlatform(c) {
@@ -159,6 +166,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
stream := action == "streamGenerateContent"
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
body, err := io.ReadAll(c.Request.Body)
if err != nil {
@@ -187,8 +195,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
reqLog.Warn("gemini.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
reqLog.Info("gemini.user_wait_queue_full", zap.Int("max_wait", maxWait))
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
@@ -208,6 +217,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil {
reqLog.Warn("gemini.user_slot_acquire_failed", zap.Error(err))
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
@@ -223,6 +233,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
status, _, message := billingErrorDetails(err)
googleError(c, status, message)
return
@@ -252,6 +263,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
if sessionBoundAccountID > 0 {
prefetchedGroupID := int64(0)
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
c.Request = c.Request.WithContext(ctx)
}
}
// === Gemini 内容摘要会话 Fallback 逻辑 ===
@@ -296,8 +316,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
matchedDigestChain = foundMatchedChain
sessionBoundAccountID = foundAccountID
geminiSessionUUID = foundUUID
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain))
reqLog.Info("gemini.digest_fallback_matched",
zap.String("session_uuid_prefix", safeShortPrefix(foundUUID, 8)),
zap.Int64("account_id", foundAccountID),
zap.String("digest_chain", truncateDigestChain(geminiDigestChain)),
)
// 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey
// 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号
@@ -321,11 +344,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
@@ -335,41 +354,44 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
if len(fs.FailedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
return
default: // FailoverExhausted
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
return
}
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
reqLog.Info("gemini.sticky_session_account_switched",
zap.Int64("from_account_id", sessionBoundAccountID),
zap.Int64("to_account_id", account.ID),
zap.Bool("clean_thought_signature", true),
)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
// 为避免第一次转发就 400这里做一次确定性清理让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
reqLog.Info("gemini.sticky_session_binding_missing",
zap.Bool("clean_thought_signature", true),
)
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
@@ -388,9 +410,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountWaitCounted := false
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
reqLog.Warn("gemini.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
reqLog.Info("gemini.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
@@ -412,6 +437,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
&streamStarted,
)
if err != nil {
reqLog.Warn("gemini.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
@@ -420,7 +446,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
reqLog.Warn("gemini.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
@@ -429,8 +455,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 5) forward (根据平台分流)
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
}
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
@@ -443,27 +469,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
lastFailoverErr = failoverErr
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch failoverAction {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
return
case FailoverCanceled:
return
}
lastFailoverErr = failoverErr
switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err)
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
return
}
@@ -482,31 +500,39 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account.ID,
matchedDigestChain,
); err != nil {
log.Printf("[Gemini] Failed to save digest session: %v", err)
reqLog.Warn("gemini.digest_session_save_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 6) record usage async (Gemini 使用长上下文双倍计费)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Account: account,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
UserAgent: userAgent,
IPAddress: clientIP,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fcb,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.gemini_v1beta.models"),
zap.Int64("user_id", authSubject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", modelName),
zap.Int64("account_id", account.ID),
).Error("gemini.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, forceCacheBilling)
})
reqLog.Debug("gemini.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", fs.SwitchCount),
)
return
}
}

View File

@@ -39,6 +39,7 @@ type Handlers struct {
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
SoraGateway *SoraGatewayHandler
Setting *SettingHandler
Totp *TotpHandler
}

View File

@@ -0,0 +1,65 @@
package handler
import (
"context"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
func executeUserIdempotentJSON(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
coordinator := service.DefaultIdempotencyCoordinator()
if coordinator == nil {
data, err := execute(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
return
}
actorScope := "user:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "user:" + strconv.FormatInt(subject.UserID, 10)
}
result, err := coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
Scope: scope,
ActorScope: actorScope,
Method: c.Request.Method,
Route: c.FullPath(),
IdempotencyKey: c.GetHeader("Idempotency-Key"),
Payload: payload,
RequireKey: true,
TTL: ttl,
}, execute)
if err != nil {
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_fail_close")
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=fail_close", c.Request.Method, c.FullPath(), scope)
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}

Some files were not shown because too many files have changed in this diff Show More