Compare commits

...

45 Commits

Author SHA1 Message Date
Wesley Liddick
269a659200 Merge pull request #406 from geminiwen/main
fix(openai-oauth): 改进错误处理和代理支持
2026-01-28 13:53:44 +08:00
Wesley Liddick
2c31bf46b5 Merge pull request #401 from slovx2/heihuzi_main
feat(gemini): 为 Gemini 原生平台添加图片计费支持
2026-01-28 13:51:14 +08:00
Gemini Wen
8f6639f825 fix(response): add nil check for c.Request in error logging
Prevents panic when ErrorFrom is called in test contexts where
gin.CreateTestContext doesn't set up an HTTP request.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 19:26:44 +08:00
Gemini Wen
fc17d9d7df chore: bump version to 0.1.61 and fix tests
- Update VERSION from 0.1.46 to 0.1.61
- Remove ForceHTTP2 tests for OpenAI OAuth client (ForceHTTP2 was removed)
- Update createOpenAIReqClient test to use new single-arg signature

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 19:22:45 +08:00
Gemini Wen
ab092e88a8 fix(openai-oauth): 改进错误处理和代理支持
- 使用 ApplicationError 返回详细错误信息到前端
- 添加 User-Agent: codex-cli/0.91.0
- 移除 ForceHTTP2 以兼容 HTTP 代理
- 修复代理获取失败时静默忽略的问题
- 500 错误时记录完整错误日志

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 19:13:01 +08:00
shaw
56a1e29cdd fix(gateway): 修复 SSE 流式响应 usage 统计错误
message_delta 应完全覆盖 message_start 的 usage 数据,
而非仅在值为 0 时才更新。
2026-01-27 09:16:34 +08:00
song
0059a232a6 feat(gemini): 为 Gemini 原生平台添加图片计费支持
对齐 Antigravity 平台的图片计费逻辑:
- 添加 extractImageSize() 方法提取图片尺寸
- Forward() 和 ForwardNative() 返回 ImageCount/ImageSize
- 支持分组自定义图片价格和倍率
2026-01-26 20:51:40 +08:00
shaw
45676fdc8d fix(ci): 转义 Telegram 消息中的 Markdown 特殊字符
修复发布通知发送失败的问题,原因是 tag message 中包含未闭合的
Markdown 格式标记(如 user_id 中的 _ 被解析为斜体开始)导致
Telegram API 返回解析错误。

添加 sed 命令转义 _、*、` 和 [ 字符,避免被 Telegram Markdown
解析器错误处理。
2026-01-26 11:07:08 +08:00
Wesley Liddick
e32c5f534f Merge pull request #386 from IanShaw027/fix/openai-usage-limit-reset-time
fix(ratelimit): 修复 OpenAI usage_limit_reached 错误的重置时间解析
2026-01-26 10:22:42 +08:00
shaw
426d691c95 fix(urlvalidator): 移除 ValidateURLFormat 返回值的末尾斜杠
修复 API Key 账号 base_url 末尾带斜杠时导致的双斜杠问题
2026-01-26 10:21:41 +08:00
shaw
e9a4c8ab97 docs: 修改演示站点域名 2026-01-26 10:04:44 +08:00
ianshaw
a55cfebd09 fix(ratelimit): 修复 OpenAI usage_limit_reached 错误的重置时间解析
- 问题:OpenAI 的 usage_limit_reached 错误(需 37 小时重置)被错误地设置为 5 分钟
- 原因:handle429 只检查 Anthropic 响应头,没有解析 OpenAI 响应体中的 resets_in_seconds
- 修复:新增 parseOpenAIRateLimitResetTime 函数解析 OpenAI 响应体
- 影响:避免调度器不断尝试已达配额上限的账户
2026-01-26 09:57:44 +08:00
Wesley Liddick
34cc02f8c7 Merge pull request #393 from IanShaw027/fix/gemini-thought-signature-preserve
fix(gemini): 修复 thoughtSignature 跨账号验证错误
2026-01-26 09:23:46 +08:00
Wesley Liddick
624d9fddb7 Merge pull request #391 from geminiwen/main
fix(subscription): 修复订阅调整逻辑,已过期订阅从当前时间计算
2026-01-26 09:23:29 +08:00
Wesley Liddick
47fbe43324 Merge pull request #385 from DDZS987/fix/oauth-token-refresh-missing-project-id-retry
fix(oauth): 修复 OAuth 令牌刷新时 missing_project_id 误报问题
2026-01-26 09:22:48 +08:00
shaw
1245f07a2d feat(auth): 实现 TOTP 双因素认证功能
新增功能:
- 支持 Google Authenticator 等应用进行 TOTP 二次验证
- 用户可在个人设置中启用/禁用 2FA
- 登录时支持 TOTP 验证流程
- 管理后台可全局开关 TOTP 功能

安全增强:
- TOTP 密钥使用 AES-256-GCM 加密存储
- 添加 TOTP_ENCRYPTION_KEY 配置项,必须手动配置才能启用功能
- 防止服务重启导致加密密钥变更使用户无法登录
- 验证失败次数限制,防止暴力破解

配置说明:
- Docker 部署:在 .env 中设置 TOTP_ENCRYPTION_KEY
- 非 Docker 部署:在 config.yaml 中设置 totp.encryption_key
- 生成密钥命令:openssl rand -hex 32
2026-01-26 09:19:53 +08:00
ianshaw
839975b0cf feat(gemini): 支持 Gemini CLI 粘性会话与跨账号 thoughtSignature 清理
## 问题背景

1. Gemini CLI 没有明确的会话标识(如 Claude Code 的 metadata.user_id)
2. thoughtSignature 与具体上游账号强绑定,跨账号使用会导致 400 错误
3. 粘性会话切换账号或 cache 丢失时,旧签名会导致请求失败

## 解决方案

### 1. Gemini CLI 会话标识提取

- 从 `x-gemini-api-privileged-user-id` header 和请求体中的 tmp 目录哈希生成会话标识
- 组合策略:SHA256(privileged-user-id + ":" + tmp_dir_hash)
- 正则提取:`/\.gemini/tmp/([A-Fa-f0-9]{64})`

### 2. 跨账号 thoughtSignature 清理

实现三种场景的智能清理:

1. **Cache 命中 + 账号切换**
   - 粘性会话绑定的账号与当前选择的账号不同时清理

2. **同一请求内 failover 切换**
   - 通过 sessionBoundAccountID 跟踪,检测重试时的账号切换

3. **Gemini CLI + Cache 未命中 + 含签名**
   - 预防性清理,避免 cache 丢失后首次转发就 400
   - 仅对 Gemini CLI 请求且请求体包含 thoughtSignature 时触发

## 修改内容

### backend/internal/handler/gemini_v1beta_handler.go
- 添加 `extractGeminiCLISessionHash` 函数提取 Gemini CLI 会话标识
- 添加 `isGeminiCLIRequest` 函数识别 Gemini CLI 请求
- 实现账号切换检测与 thoughtSignature 清理逻辑
- 添加 `geminiCLITmpDirRegex` 正则表达式

### backend/internal/service/gateway_service.go
- 添加 `GetCachedSessionAccountID` 方法查询粘性会话绑定的账号 ID

### backend/internal/service/gemini_native_signature_cleaner.go (新增)
- 实现 `CleanGeminiNativeThoughtSignatures` 函数
- 递归清理 JSON 中的所有 thoughtSignature 字段
- 支持任意 JSON 顶层类型(object/array)

### backend/internal/handler/gemini_cli_session_test.go (新增)
- 测试 Gemini CLI 会话哈希提取逻辑
- 测试 tmp 目录正则匹配
- 覆盖有/无 privileged-user-id 的场景

## 影响范围

- 修复 Gemini CLI 多轮对话时账号切换导致的 400 错误
- 提高粘性会话的稳定性和容错能力
- 不影响其他客户端(Claude Code 等)的会话标识生成

## 测试

- 单元测试:go test -tags=unit ./internal/handler -run TestExtractGeminiCLISessionHash
- 单元测试:go test -tags=unit ./internal/handler -run TestGeminiCLITmpDirRegex
- 编译验证:go build ./cmd/server
2026-01-26 04:40:38 +08:00
ianshaw
8c1233393f fix(antigravity): 修复 Gemini 模型 thoughtSignature 被错误覆盖的问题
## 问题描述

在使用 Gemini 模型(gemini-3-flash-preview)时,出现 400 错误:
"Unable to submit request because Thought signature is not valid"

## 根本原因

在 `request_transformer.go` 的 `buildParts()` 函数中:
- 对于 `tool_use` 和 `thinking` 块,当 `allowDummyThought=true`(Gemini 模型)时
- 代码会无条件将客户端传入的真实 `thoughtSignature` 覆盖成 dummy 值
- 导致 Gemini API 验证签名失败(签名与上下文不匹配)

## 修复方案

修改 signature 处理逻辑:
1. **优先透传真实 signature**:如果客户端提供了有效的 signature,保留它
2. **缺失时才使用 dummy**:只有在 signature 缺失且是 Gemini 模型时,才使用 dummy signature
3. **Claude 模型特殊处理**:将 dummy signature 视为缺失,避免透传到需要真实签名的链路

## 修改内容

### request_transformer.go
- `thinking` 块(第 367-386 行):优先透传真实 signature
- `tool_use` 块(第 411-418 行):优先透传真实 signature

### request_transformer_test.go
- 修改测试用例名称,反映新的行为
- 新增测试用例验证"缺失时才使用 dummy"的逻辑

## 影响范围

- 修复 Gemini 模型在多轮对话中使用 tool_use 时的签名验证错误
- 不影响 Claude 模型的现有行为
- 提高跨账号切换时的稳定性

相关问题:#[issue_number]
2026-01-26 03:06:45 +08:00
Gemini Wen
9cdb0568cc fix(subscription): 修复订阅调整逻辑,已过期订阅从当前时间计算
- 已过期订阅延长时,从当前时间开始增加天数
- 已过期订阅不允许负向调整(缩短)
- 未过期订阅保持原逻辑,从原过期时间增减

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-25 18:12:15 +08:00
shaw
74e05b83ea fix(ratelimit): 修复 OpenAI 账号限流倒计时计算错误
- 解析 x-codex-* 响应头获取正确的重置时间
- 7d 限制用尽时使用 codex_7d_reset_after_seconds
- 提取 Normalize() 方法统一窗口规范化逻辑
2026-01-25 13:32:08 +08:00
Ubuntu
4ded9e7d49 fix(oauth): 为初始 OAuth 授权添加 LoadCodeAssist 重试机制
问题:
- 初始授权时 LoadCodeAssist 没有重试机制,失败后直接跳过
- 导致账号创建时就可能缺失 project_id
- 之后每次刷新都因为 missing_project_id 报错

修复:
- 统一使用 loadProjectIDWithRetry 方法(最多4次尝试)
- 初始授权和token刷新使用相同的重试策略
- 保留原注释说明部分账户可能没有 project_id

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-24 23:41:36 +08:00
Ubuntu
716272a1e2 fix(oauth): 彻底修复 project_id 丢失问题
根本原因:
- BuildAccountCredentials 只在 project_id 非空时才添加该字段
- LoadCodeAssist 失败时返回空字符串 → 新 credentials 不包含 project_id 键
- 普通合并逻辑只保留新 credentials 中不存在的键,无法覆盖空值

解决方案:
1. 在合并后特殊处理 project_id:如果新值为空但旧值非空,保留旧值
2. LoadCodeAssist 失败不再返回错误,只记录警告
3. Token 刷新成功(access_token 已更新)就不应标记账户为 error

改进效果:
- 即使 LoadCodeAssist 连续失败,已有的 project_id 也不会丢失
- 避免因临时网络问题将账户误标记为不可用
- 允许在下次刷新时自动重试获取 project_id

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-24 23:04:48 +08:00
shaw
9cc8352593 feat(auth): 密码重置邮件队列化与限流优化
- 邮件发送改为异步队列处理,避免并发导致发送失败
- 新增 Email 维度限流(30秒冷却期),防止邮件轰炸
- Token 验证使用常量时间比较,防止时序攻击
- 重构代码消除冗余,提取公共验证逻辑
2026-01-24 22:55:28 +08:00
shaw
43a1031e38 fix(test): 修复订阅相关测试失败问题
1. 使用未来日期(2099年)作为测试订阅的过期时间,避免
   normalizeSubscriptionStatus 将测试数据标记为过期
2. 修复 List 方法调用参数不足的问题(新增 sortBy/sortOrder 参数)
2026-01-24 21:10:02 +08:00
Wesley Liddick
a5547b2f30 Merge pull request #380 from DDZS987/fix/oauth-token-refresh-missing-project-id-retry
fix(oauth): 修复 OAuth 令牌刷新时 missing_project_id 误报问题
2026-01-24 20:29:43 +08:00
shaw
b0aa23540b feat(subscription): 订阅过期状态自动更新与服务端排序
- 新增 SubscriptionExpiryService 定时任务,每分钟更新过期订阅状态
- 订阅列表支持服务端排序(按过期时间、状态、创建时间)
- 实时显示正确的过期状态,无需等待定时任务
- 允许对已过期订阅进行续期操作
- DataTable 组件支持 serverSideSort 模式
2026-01-24 20:26:01 +08:00
Ubuntu
ffaa6c4a17 fix(oauth): 修复 OAuth 令牌刷新时 missing_project_id 误报问题
问题描述:
在网络波动环境下,LoadCodeAssist 临时失败会被错误地标记为
"missing_project_id" 不可重试错误,导致账户被禁用。但实际上
账户配置正确,手动刷新后即可恢复。

解决方案:
1. 为 LoadCodeAssist 增加重试机制(最多4次,指数退避)
2. 区分真正的配置缺失和临时网络故障:
   - 如果之前有 project_id,本次获取失败只保留旧值,不报错
   - 只有从未获取过 project_id 且本次也失败时,才标记为缺失
3. 优化错误判断逻辑,避免误报

改进效果:
- 提高在复杂网络环境(如家宽代理)下的鲁棒性
- 减少因临时网络波动导致的服务中断
- 保持真正配置错误的检测能力

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-24 17:44:56 +08:00
Wesley Liddick
fbf72f0ec4 Merge pull request #377 from lynoot/fix/non-streaming-chunk-aggregation
fix(gateway): aggregate all text chunks in non-streaming Gemini responses
2026-01-24 08:49:10 +08:00
lynoot
909b8a8f9c fix(gateway): aggregate all text chunks in non-streaming Gemini responses
Previously, collectGeminiSSE() only returned the last chunk received
from the upstream streaming response when converting to non-streaming.
This caused incomplete responses where only the final text fragment
was returned to clients.

For example, a request asking to "count from 1 to 10" would only
return "\n" (the last chunk) instead of "1\n2\n3\n...\n10\n".

This was especially problematic for JSON structured output where
the opening brace "{" from the first chunk was lost, resulting
in invalid JSON like: colors": ["red", "blue"]}

The fix:
- Collect all text parts from each SSE chunk into a slice
- Merge all collected text parts into the final response
- Reuse the same pattern as handleGeminiStreamToNonStreaming
  in antigravity_gateway_service.go

Fixes: non-streaming responses returning incomplete text
Fixes: structured output (JSON schema) returning invalid JSON
2026-01-23 13:54:09 +00:00
shaw
4a0fe3b143 feat(gateway): 增加 SUGGESTION MODE 请求拦截
扩展现有的预热请求拦截功能,新增对 SUGGESTION MODE 请求的拦截:
- 检测 messages 最后一条 user 消息是否以 [SUGGESTION MODE: 开头
- 拦截后返回空内容响应,节省 token 消耗
- 重构检测逻辑,合并为单一函数,只解析一次 JSON
2026-01-23 16:57:25 +08:00
shaw
a1292fac81 feat(oauth): 支持Anthropic的Team账号使用sk授权 2026-01-23 16:30:12 +08:00
shaw
7f98be4f91 fix(oauth): 更新 Anthropic 账号 OAuth 参数,同步最新客户端 2026-01-23 16:00:42 +08:00
shaw
fd73b8875d feat(frontend): 优化账号限流状态显示,直接展示倒计时 2026-01-23 15:48:25 +08:00
shaw
f9ab1daa3c feat: 保存并显示OAuth账号邮箱地址 2026-01-23 15:17:47 +08:00
shaw
d27b847442 fix(test): 修复测试中引用不存在的函数名
测试文件引用了 IsTokenVersionStale 函数,但实际函数名为 CheckTokenVersion,导致 CI 构建失败
2026-01-23 10:58:30 +08:00
shaw
dac6bc2228 fix(token-cache): 版本过时时使用最新token而非旧token
上次修复(2665230)只阻止了写入缓存,但仍返回旧token导致403。
现在版本过时时直接使用DB中的最新token返回。

- 重构 IsTokenVersionStale 为 CheckTokenVersion,返回最新account
- 消除重复DB查询,复用版本检查时已获取的account
2026-01-23 10:29:52 +08:00
Wesley Liddick
4bd3dbf2ce Merge pull request #354 from DuckyProject/fix/frontend-table
feat(frontend): 账号表格默认排序/持久化 + 自动刷新 + 更多菜单外部关闭
2026-01-22 21:17:48 +08:00
Wesley Liddick
226df1c23a Merge pull request #358 from 0xff26b9a8/main
fix(antigravity): 修复非流式 Claude To Antigravity 响应内容为空的问题
2026-01-22 21:16:54 +08:00
shaw
2665230a09 fix(token-cache): 修复异步刷新与请求线程的缓存竞态条件
- 新增 _token_version 版本号机制,防止过期 token 污染缓存
- TokenRefreshService 刷新成功后写入版本号并清除缓存
- TokenProvider 写入缓存前检查版本,过时则跳过
- ClearError 时同步清除 token 缓存
2026-01-22 21:09:28 +08:00
0xff26b9a8
4f0c2b794c style: gofmt antigravity_gateway_service.go 2026-01-22 14:38:55 +08:00
0xff26b9a8
e756064c19 fix(antigravity): 修复非流式 Claude To Antigravity 响应内容为空的问题
- 修复 TransformGeminiToClaude 的 JSON 解析逻辑,当 V1InternalResponse
  解析成功但 candidates 为空时,尝试直接解析为 GeminiResponse 格式
- 修复 handleClaudeStreamToNonStreaming 收集流式响应的逻辑,累积所有
  chunks 的内容而不是只保留最后一个(最后一个 chunk 通常 text 为空)
- 新增 mergeCollectedPartsToResponse 函数,合并所有类型的 parts
  (text、thinking、functionCall、inlineData),保持原始顺序
- 连续的普通 text parts 合并为一个,thinking/functionCall/inlineData 保持原样
2026-01-22 14:17:59 +08:00
Wesley Liddick
17dfb0af01 Merge pull request #346 from 0xff26b9a8/main
refactor(antigravity): 提取并同步 Schema 清理逻辑至 schema_cleaner.go
2026-01-22 08:46:11 +08:00
ducky
ff74f517df feat(frontend): 账号表格默认排序/持久化 + 自动刷新 + 更多菜单外部关闭 2026-01-21 22:43:25 +08:00
0xff26b9a8
477a9a180f fix: 修复 schema 清理逻辑 2026-01-21 10:58:39 +08:00
0xff26b9a8
da48df06d2 refactor(antigravity): 提取并同步 Schema 清理逻辑至 schema_cleaner.go
主要变更:
1. 重构代码结构:
   - 将 CleanJSONSchema 及其相关辅助函数从 request_transformer.go 提取到独立的 schema_cleaner.go 文件中,实现逻辑解耦。

2. 逻辑优化与修正:
   - 参考 Antigravity-Manager (json_schema.rs) 的实现逻辑,修正了 Schema 清洗策略。
2026-01-20 23:41:53 +08:00
125 changed files with 9077 additions and 1199 deletions

View File

@@ -222,8 +222,9 @@ jobs:
REPO="${{ github.repository }}"
GHCR_IMAGE="ghcr.io/${REPO,,}" # ${,,} converts to lowercase
# 获取 tag message 内容
# 获取 tag message 内容并转义 Markdown 特殊字符
TAG_MESSAGE='${{ steps.tag_message.outputs.message }}'
TAG_MESSAGE=$(echo "$TAG_MESSAGE" | sed 's/\([_*`\[]\)/\\\1/g')
# 限制消息长度Telegram 消息限制 4096 字符,预留空间给头尾固定内容)
if [ ${#TAG_MESSAGE} -gt 3500 ]; then

View File

@@ -18,7 +18,7 @@ English | [中文](README_CN.md)
## Demo
Try Sub2API online: **https://v2.pincc.ai/**
Try Sub2API online: **https://demo.sub2api.org/**
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):

View File

@@ -1 +1 @@
0.1.46
0.1.61

View File

@@ -70,6 +70,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
@@ -138,6 +139,10 @@ func provideCleanup(
accountExpiry.Stop()
return nil
}},
{"SubscriptionExpiryService", func() error {
subscriptionExpiry.Stop()
return nil
}},
{"PricingService", func() error {
pricing.Stop()
return nil

View File

@@ -63,7 +63,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil {
return nil, err
}
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService)
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
@@ -165,7 +171,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
totpHandler := handler.NewTotpHandler(totpService)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -178,7 +185,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -211,6 +219,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
@@ -278,6 +287,10 @@ func provideCleanup(
accountExpiry.Stop()
return nil
}},
{"SubscriptionExpiryService", func() error {
subscriptionExpiry.Stop()
return nil
}},
{"PricingService", func() error {
pricing.Stop()
return nil

View File

@@ -610,6 +610,9 @@ var (
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{

View File

@@ -14360,6 +14360,9 @@ type UserMutation struct {
status *string
username *string
notes *string
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -14937,6 +14940,140 @@ func (m *UserMutation) ResetNotes() {
m.notes = nil
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (m *UserMutation) SetTotpSecretEncrypted(s string) {
m.totp_secret_encrypted = &s
}
// TotpSecretEncrypted returns the value of the "totp_secret_encrypted" field in the mutation.
func (m *UserMutation) TotpSecretEncrypted() (r string, exists bool) {
v := m.totp_secret_encrypted
if v == nil {
return
}
return *v, true
}
// OldTotpSecretEncrypted returns the old "totp_secret_encrypted" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldTotpSecretEncrypted(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldTotpSecretEncrypted is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldTotpSecretEncrypted requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldTotpSecretEncrypted: %w", err)
}
return oldValue.TotpSecretEncrypted, nil
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (m *UserMutation) ClearTotpSecretEncrypted() {
m.totp_secret_encrypted = nil
m.clearedFields[user.FieldTotpSecretEncrypted] = struct{}{}
}
// TotpSecretEncryptedCleared returns if the "totp_secret_encrypted" field was cleared in this mutation.
func (m *UserMutation) TotpSecretEncryptedCleared() bool {
_, ok := m.clearedFields[user.FieldTotpSecretEncrypted]
return ok
}
// ResetTotpSecretEncrypted resets all changes to the "totp_secret_encrypted" field.
func (m *UserMutation) ResetTotpSecretEncrypted() {
m.totp_secret_encrypted = nil
delete(m.clearedFields, user.FieldTotpSecretEncrypted)
}
// SetTotpEnabled sets the "totp_enabled" field.
func (m *UserMutation) SetTotpEnabled(b bool) {
m.totp_enabled = &b
}
// TotpEnabled returns the value of the "totp_enabled" field in the mutation.
func (m *UserMutation) TotpEnabled() (r bool, exists bool) {
v := m.totp_enabled
if v == nil {
return
}
return *v, true
}
// OldTotpEnabled returns the old "totp_enabled" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldTotpEnabled(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldTotpEnabled is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldTotpEnabled requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldTotpEnabled: %w", err)
}
return oldValue.TotpEnabled, nil
}
// ResetTotpEnabled resets all changes to the "totp_enabled" field.
func (m *UserMutation) ResetTotpEnabled() {
m.totp_enabled = nil
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (m *UserMutation) SetTotpEnabledAt(t time.Time) {
m.totp_enabled_at = &t
}
// TotpEnabledAt returns the value of the "totp_enabled_at" field in the mutation.
func (m *UserMutation) TotpEnabledAt() (r time.Time, exists bool) {
v := m.totp_enabled_at
if v == nil {
return
}
return *v, true
}
// OldTotpEnabledAt returns the old "totp_enabled_at" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldTotpEnabledAt(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldTotpEnabledAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldTotpEnabledAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldTotpEnabledAt: %w", err)
}
return oldValue.TotpEnabledAt, nil
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (m *UserMutation) ClearTotpEnabledAt() {
m.totp_enabled_at = nil
m.clearedFields[user.FieldTotpEnabledAt] = struct{}{}
}
// TotpEnabledAtCleared returns if the "totp_enabled_at" field was cleared in this mutation.
func (m *UserMutation) TotpEnabledAtCleared() bool {
_, ok := m.clearedFields[user.FieldTotpEnabledAt]
return ok
}
// ResetTotpEnabledAt resets all changes to the "totp_enabled_at" field.
func (m *UserMutation) ResetTotpEnabledAt() {
m.totp_enabled_at = nil
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -15403,7 +15540,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
fields := make([]string, 0, 11)
fields := make([]string, 0, 14)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -15437,6 +15574,15 @@ func (m *UserMutation) Fields() []string {
if m.notes != nil {
fields = append(fields, user.FieldNotes)
}
if m.totp_secret_encrypted != nil {
fields = append(fields, user.FieldTotpSecretEncrypted)
}
if m.totp_enabled != nil {
fields = append(fields, user.FieldTotpEnabled)
}
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
return fields
}
@@ -15467,6 +15613,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.Username()
case user.FieldNotes:
return m.Notes()
case user.FieldTotpSecretEncrypted:
return m.TotpSecretEncrypted()
case user.FieldTotpEnabled:
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
}
return nil, false
}
@@ -15498,6 +15650,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldUsername(ctx)
case user.FieldNotes:
return m.OldNotes(ctx)
case user.FieldTotpSecretEncrypted:
return m.OldTotpSecretEncrypted(ctx)
case user.FieldTotpEnabled:
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -15584,6 +15742,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetNotes(v)
return nil
case user.FieldTotpSecretEncrypted:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetTotpSecretEncrypted(v)
return nil
case user.FieldTotpEnabled:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetTotpEnabled(v)
return nil
case user.FieldTotpEnabledAt:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetTotpEnabledAt(v)
return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -15644,6 +15823,12 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldDeletedAt) {
fields = append(fields, user.FieldDeletedAt)
}
if m.FieldCleared(user.FieldTotpSecretEncrypted) {
fields = append(fields, user.FieldTotpSecretEncrypted)
}
if m.FieldCleared(user.FieldTotpEnabledAt) {
fields = append(fields, user.FieldTotpEnabledAt)
}
return fields
}
@@ -15661,6 +15846,12 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldDeletedAt:
m.ClearDeletedAt()
return nil
case user.FieldTotpSecretEncrypted:
m.ClearTotpSecretEncrypted()
return nil
case user.FieldTotpEnabledAt:
m.ClearTotpEnabledAt()
return nil
}
return fmt.Errorf("unknown User nullable field %s", name)
}
@@ -15702,6 +15893,15 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldNotes:
m.ResetNotes()
return nil
case user.FieldTotpSecretEncrypted:
m.ResetTotpSecretEncrypted()
return nil
case user.FieldTotpEnabled:
m.ResetTotpEnabled()
return nil
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
}
return fmt.Errorf("unknown User field %s", name)
}

View File

@@ -736,6 +736,10 @@ func init() {
userDescNotes := userFields[7].Descriptor()
// user.DefaultNotes holds the default value on creation for the notes field.
user.DefaultNotes = userDescNotes.Default.(string)
// userDescTotpEnabled is the schema descriptor for totp_enabled field.
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.

View File

@@ -61,6 +61,17 @@ func (User) Fields() []ent.Field {
field.String("notes").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Default(""),
// TOTP 双因素认证字段
field.String("totp_secret_encrypted").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Optional().
Nillable(),
field.Bool("totp_enabled").
Default(false),
field.Time("totp_enabled_at").
Optional().
Nillable(),
}
}

View File

@@ -39,6 +39,12 @@ type User struct {
Username string `json:"username,omitempty"`
// Notes holds the value of the "notes" field.
Notes string `json:"notes,omitempty"`
// TotpSecretEncrypted holds the value of the "totp_secret_encrypted" field.
TotpSecretEncrypted *string `json:"totp_secret_encrypted,omitempty"`
// TotpEnabled holds the value of the "totp_enabled" field.
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -156,13 +162,15 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case user.FieldTotpEnabled:
values[i] = new(sql.NullBool)
case user.FieldBalance:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes:
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -252,6 +260,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Notes = value.String
}
case user.FieldTotpSecretEncrypted:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field totp_secret_encrypted", values[i])
} else if value.Valid {
_m.TotpSecretEncrypted = new(string)
*_m.TotpSecretEncrypted = value.String
}
case user.FieldTotpEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field totp_enabled", values[i])
} else if value.Valid {
_m.TotpEnabled = value.Bool
}
case user.FieldTotpEnabledAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field totp_enabled_at", values[i])
} else if value.Valid {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -367,6 +395,19 @@ func (_m *User) String() string {
builder.WriteString(", ")
builder.WriteString("notes=")
builder.WriteString(_m.Notes)
builder.WriteString(", ")
if v := _m.TotpSecretEncrypted; v != nil {
builder.WriteString("totp_secret_encrypted=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("totp_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.TotpEnabled))
builder.WriteString(", ")
if v := _m.TotpEnabledAt; v != nil {
builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteByte(')')
return builder.String()
}

View File

@@ -37,6 +37,12 @@ const (
FieldUsername = "username"
// FieldNotes holds the string denoting the notes field in the database.
FieldNotes = "notes"
// FieldTotpSecretEncrypted holds the string denoting the totp_secret_encrypted field in the database.
FieldTotpSecretEncrypted = "totp_secret_encrypted"
// FieldTotpEnabled holds the string denoting the totp_enabled field in the database.
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -134,6 +140,9 @@ var Columns = []string{
FieldStatus,
FieldUsername,
FieldNotes,
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
}
var (
@@ -188,6 +197,8 @@ var (
UsernameValidator func(string) error
// DefaultNotes holds the default value on creation for the "notes" field.
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
)
// OrderOption defines the ordering options for the User queries.
@@ -253,6 +264,21 @@ func ByNotes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldNotes, opts...).ToFunc()
}
// ByTotpSecretEncrypted orders the results by the totp_secret_encrypted field.
func ByTotpSecretEncrypted(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpSecretEncrypted, opts...).ToFunc()
}
// ByTotpEnabled orders the results by the totp_enabled field.
func ByTotpEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabled, opts...).ToFunc()
}
// ByTotpEnabledAt orders the results by the totp_enabled_at field.
func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -110,6 +110,21 @@ func Notes(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldNotes, v))
}
// TotpSecretEncrypted applies equality check predicate on the "totp_secret_encrypted" field. It's identical to TotpSecretEncryptedEQ.
func TotpSecretEncrypted(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
}
// TotpEnabled applies equality check predicate on the "totp_enabled" field. It's identical to TotpEnabledEQ.
func TotpEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
}
// TotpEnabledAt applies equality check predicate on the "totp_enabled_at" field. It's identical to TotpEnabledAtEQ.
func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -710,6 +725,141 @@ func NotesContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldNotes, v))
}
// TotpSecretEncryptedEQ applies the EQ predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedNEQ applies the NEQ predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedIn applies the In predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldTotpSecretEncrypted, vs...))
}
// TotpSecretEncryptedNotIn applies the NotIn predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldTotpSecretEncrypted, vs...))
}
// TotpSecretEncryptedGT applies the GT predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedGTE applies the GTE predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedLT applies the LT predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedLTE applies the LTE predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedContains applies the Contains predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedHasPrefix applies the HasPrefix predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedHasSuffix applies the HasSuffix predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedIsNil applies the IsNil predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldTotpSecretEncrypted))
}
// TotpSecretEncryptedNotNil applies the NotNil predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpSecretEncrypted))
}
// TotpSecretEncryptedEqualFold applies the EqualFold predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedContainsFold applies the ContainsFold predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldTotpSecretEncrypted, v))
}
// TotpEnabledEQ applies the EQ predicate on the "totp_enabled" field.
func TotpEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
}
// TotpEnabledNEQ applies the NEQ predicate on the "totp_enabled" field.
func TotpEnabledNEQ(v bool) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotpEnabled, v))
}
// TotpEnabledAtEQ applies the EQ predicate on the "totp_enabled_at" field.
func TotpEnabledAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// TotpEnabledAtNEQ applies the NEQ predicate on the "totp_enabled_at" field.
func TotpEnabledAtNEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotpEnabledAt, v))
}
// TotpEnabledAtIn applies the In predicate on the "totp_enabled_at" field.
func TotpEnabledAtIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldIn(FieldTotpEnabledAt, vs...))
}
// TotpEnabledAtNotIn applies the NotIn predicate on the "totp_enabled_at" field.
func TotpEnabledAtNotIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldNotIn(FieldTotpEnabledAt, vs...))
}
// TotpEnabledAtGT applies the GT predicate on the "totp_enabled_at" field.
func TotpEnabledAtGT(v time.Time) predicate.User {
return predicate.User(sql.FieldGT(FieldTotpEnabledAt, v))
}
// TotpEnabledAtGTE applies the GTE predicate on the "totp_enabled_at" field.
func TotpEnabledAtGTE(v time.Time) predicate.User {
return predicate.User(sql.FieldGTE(FieldTotpEnabledAt, v))
}
// TotpEnabledAtLT applies the LT predicate on the "totp_enabled_at" field.
func TotpEnabledAtLT(v time.Time) predicate.User {
return predicate.User(sql.FieldLT(FieldTotpEnabledAt, v))
}
// TotpEnabledAtLTE applies the LTE predicate on the "totp_enabled_at" field.
func TotpEnabledAtLTE(v time.Time) predicate.User {
return predicate.User(sql.FieldLTE(FieldTotpEnabledAt, v))
}
// TotpEnabledAtIsNil applies the IsNil predicate on the "totp_enabled_at" field.
func TotpEnabledAtIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldTotpEnabledAt))
}
// TotpEnabledAtNotNil applies the NotNil predicate on the "totp_enabled_at" field.
func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@@ -167,6 +167,48 @@ func (_c *UserCreate) SetNillableNotes(v *string) *UserCreate {
return _c
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (_c *UserCreate) SetTotpSecretEncrypted(v string) *UserCreate {
_c.mutation.SetTotpSecretEncrypted(v)
return _c
}
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotpSecretEncrypted(v *string) *UserCreate {
if v != nil {
_c.SetTotpSecretEncrypted(*v)
}
return _c
}
// SetTotpEnabled sets the "totp_enabled" field.
func (_c *UserCreate) SetTotpEnabled(v bool) *UserCreate {
_c.mutation.SetTotpEnabled(v)
return _c
}
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotpEnabled(v *bool) *UserCreate {
if v != nil {
_c.SetTotpEnabled(*v)
}
return _c
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (_c *UserCreate) SetTotpEnabledAt(v time.Time) *UserCreate {
_c.mutation.SetTotpEnabledAt(v)
return _c
}
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
if v != nil {
_c.SetTotpEnabledAt(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -362,6 +404,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultNotes
_c.mutation.SetNotes(v)
}
if _, ok := _c.mutation.TotpEnabled(); !ok {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
return nil
}
@@ -422,6 +468,9 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.Notes(); !ok {
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
}
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
return nil
}
@@ -493,6 +542,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldNotes, field.TypeString, value)
_node.Notes = value
}
if value, ok := _c.mutation.TotpSecretEncrypted(); ok {
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
_node.TotpSecretEncrypted = &value
}
if value, ok := _c.mutation.TotpEnabled(); ok {
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
_node.TotpEnabled = value
}
if value, ok := _c.mutation.TotpEnabledAt(); ok {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -815,6 +876,54 @@ func (u *UserUpsert) UpdateNotes() *UserUpsert {
return u
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (u *UserUpsert) SetTotpSecretEncrypted(v string) *UserUpsert {
u.Set(user.FieldTotpSecretEncrypted, v)
return u
}
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotpSecretEncrypted() *UserUpsert {
u.SetExcluded(user.FieldTotpSecretEncrypted)
return u
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (u *UserUpsert) ClearTotpSecretEncrypted() *UserUpsert {
u.SetNull(user.FieldTotpSecretEncrypted)
return u
}
// SetTotpEnabled sets the "totp_enabled" field.
func (u *UserUpsert) SetTotpEnabled(v bool) *UserUpsert {
u.Set(user.FieldTotpEnabled, v)
return u
}
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotpEnabled() *UserUpsert {
u.SetExcluded(user.FieldTotpEnabled)
return u
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (u *UserUpsert) SetTotpEnabledAt(v time.Time) *UserUpsert {
u.Set(user.FieldTotpEnabledAt, v)
return u
}
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotpEnabledAt() *UserUpsert {
u.SetExcluded(user.FieldTotpEnabledAt)
return u
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
u.SetNull(user.FieldTotpEnabledAt)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1021,6 +1130,62 @@ func (u *UserUpsertOne) UpdateNotes() *UserUpsertOne {
})
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (u *UserUpsertOne) SetTotpSecretEncrypted(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotpSecretEncrypted(v)
})
}
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotpSecretEncrypted() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpSecretEncrypted()
})
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (u *UserUpsertOne) ClearTotpSecretEncrypted() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearTotpSecretEncrypted()
})
}
// SetTotpEnabled sets the "totp_enabled" field.
func (u *UserUpsertOne) SetTotpEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabled(v)
})
}
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotpEnabled() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabled()
})
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (u *UserUpsertOne) SetTotpEnabledAt(v time.Time) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabledAt(v)
})
}
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotpEnabledAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabledAt()
})
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearTotpEnabledAt()
})
}
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1393,6 +1558,62 @@ func (u *UserUpsertBulk) UpdateNotes() *UserUpsertBulk {
})
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (u *UserUpsertBulk) SetTotpSecretEncrypted(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotpSecretEncrypted(v)
})
}
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotpSecretEncrypted() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpSecretEncrypted()
})
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (u *UserUpsertBulk) ClearTotpSecretEncrypted() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearTotpSecretEncrypted()
})
}
// SetTotpEnabled sets the "totp_enabled" field.
func (u *UserUpsertBulk) SetTotpEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabled(v)
})
}
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotpEnabled() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabled()
})
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (u *UserUpsertBulk) SetTotpEnabledAt(v time.Time) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabledAt(v)
})
}
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotpEnabledAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabledAt()
})
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearTotpEnabledAt()
})
}
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -187,6 +187,60 @@ func (_u *UserUpdate) SetNillableNotes(v *string) *UserUpdate {
return _u
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (_u *UserUpdate) SetTotpSecretEncrypted(v string) *UserUpdate {
_u.mutation.SetTotpSecretEncrypted(v)
return _u
}
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotpSecretEncrypted(v *string) *UserUpdate {
if v != nil {
_u.SetTotpSecretEncrypted(*v)
}
return _u
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (_u *UserUpdate) ClearTotpSecretEncrypted() *UserUpdate {
_u.mutation.ClearTotpSecretEncrypted()
return _u
}
// SetTotpEnabled sets the "totp_enabled" field.
func (_u *UserUpdate) SetTotpEnabled(v bool) *UserUpdate {
_u.mutation.SetTotpEnabled(v)
return _u
}
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotpEnabled(v *bool) *UserUpdate {
if v != nil {
_u.SetTotpEnabled(*v)
}
return _u
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (_u *UserUpdate) SetTotpEnabledAt(v time.Time) *UserUpdate {
_u.mutation.SetTotpEnabledAt(v)
return _u
}
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotpEnabledAt(v *time.Time) *UserUpdate {
if v != nil {
_u.SetTotpEnabledAt(*v)
}
return _u
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
_u.mutation.ClearTotpEnabledAt()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -603,6 +657,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
}
if _u.mutation.TotpSecretEncryptedCleared() {
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
}
if value, ok := _u.mutation.TotpEnabled(); ok {
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.TotpEnabledAt(); ok {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
}
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1147,6 +1216,60 @@ func (_u *UserUpdateOne) SetNillableNotes(v *string) *UserUpdateOne {
return _u
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (_u *UserUpdateOne) SetTotpSecretEncrypted(v string) *UserUpdateOne {
_u.mutation.SetTotpSecretEncrypted(v)
return _u
}
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotpSecretEncrypted(v *string) *UserUpdateOne {
if v != nil {
_u.SetTotpSecretEncrypted(*v)
}
return _u
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (_u *UserUpdateOne) ClearTotpSecretEncrypted() *UserUpdateOne {
_u.mutation.ClearTotpSecretEncrypted()
return _u
}
// SetTotpEnabled sets the "totp_enabled" field.
func (_u *UserUpdateOne) SetTotpEnabled(v bool) *UserUpdateOne {
_u.mutation.SetTotpEnabled(v)
return _u
}
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotpEnabled(v *bool) *UserUpdateOne {
if v != nil {
_u.SetTotpEnabled(*v)
}
return _u
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (_u *UserUpdateOne) SetTotpEnabledAt(v time.Time) *UserUpdateOne {
_u.mutation.SetTotpEnabledAt(v)
return _u
}
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotpEnabledAt(v *time.Time) *UserUpdateOne {
if v != nil {
_u.SetTotpEnabledAt(*v)
}
return _u
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
_u.mutation.ClearTotpEnabledAt()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1593,6 +1716,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
}
if _u.mutation.TotpSecretEncryptedCleared() {
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
}
if value, ok := _u.mutation.TotpEnabled(); ok {
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.TotpEnabledAt(); ok {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
}
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,

View File

@@ -37,6 +37,7 @@ require (
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -106,6 +107,7 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/pquerna/otp v1.5.0 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.57.1 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect

View File

@@ -20,6 +20,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -217,6 +219,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=

View File

@@ -47,6 +47,7 @@ type Config struct {
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
@@ -466,6 +467,16 @@ type JWTConfig struct {
ExpireHour int `mapstructure:"expire_hour"`
}
// TotpConfig TOTP 双因素认证配置
type TotpConfig struct {
// EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥32 字节 hex 编码)
// 如果为空,将自动生成一个随机密钥(仅适用于开发环境)
EncryptionKey string `mapstructure:"encryption_key"`
// EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成)
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
EncryptionKeyConfigured bool `mapstructure:"-"`
}
type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
@@ -626,6 +637,20 @@ func Load() (*Config, error) {
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
if cfg.Totp.EncryptionKey == "" {
key, err := generateJWTSecret(32) // Reuse the same random generation function
if err != nil {
return nil, fmt.Errorf("generate totp encryption key error: %w", err)
}
cfg.Totp.EncryptionKey = key
cfg.Totp.EncryptionKeyConfigured = false
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
} else {
cfg.Totp.EncryptionKeyConfigured = true
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
@@ -756,6 +781,9 @@ func setDefaults() {
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
// TOTP
viper.SetDefault("totp.encryption_key", "")
// Default
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.

View File

@@ -547,9 +547,18 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败project_id 也不会丢失
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
newCredentials["project_id"] = oldProjectID
}
}
// 如果 project_id 获取失败,更新凭证但不标记为 error
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
if tokenInfo.ProjectIDMissing {
// 先更新凭证
// 先更新凭证token 本身刷新成功了)
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
@@ -557,14 +566,10 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
return
}
// 标记账户为 error
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id可能无法使用Antigravity"); setErr != nil {
response.InternalError(c, "Failed to set account error: "+setErr.Error())
return
}
// 标记为 error,只返回警告信息
response.Success(c, gin.H{
"message": "Token refreshed but project_id is missing, account marked as error",
"warning": "missing_project_id",
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
"warning": "missing_project_id_temporary",
})
return
}
@@ -668,6 +673,15 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
return
}
// 清除错误后,同时清除 token 缓存,确保下次请求会获取最新的 token触发刷新或从 DB 读取)
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
// 缓存失效失败只记录日志,不影响主流程
_ = c.Error(invalidateErr)
}
}
response.Success(c, dto.AccountFromService(account))
}

View File

@@ -48,6 +48,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
@@ -89,9 +92,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
@@ -198,6 +203,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// TOTP 双因素认证参数验证
// 只有手动配置了加密密钥才允许启用 TOTP 功能
if req.TotpEnabled && !previousSettings.TotpEnabled {
// 尝试启用 TOTP检查加密密钥是否已手动配置
if !h.settingService.IsTotpEncryptionKeyConfigured() {
response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
return
}
}
// LinuxDo Connect 参数验证
if req.LinuxDoConnectEnabled {
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
@@ -243,6 +258,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
@@ -318,6 +335,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
@@ -384,6 +404,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
if before.TotpEnabled != after.TotpEnabled {
changed = append(changed, "totp_enabled")
}
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}

View File

@@ -77,7 +77,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
}
status := c.Query("status")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
// Parse sorting parameters
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -1,6 +1,8 @@
package handler
import (
"log/slog"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
@@ -18,16 +20,18 @@ type AuthHandler struct {
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
totpService *service.TotpService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler {
return &AuthHandler{
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
totpService: totpService,
}
}
@@ -144,6 +148,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA
tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
if err != nil {
response.InternalError(c, "Failed to create 2FA session")
return
}
response.Success(c, TotpLoginResponse{
Requires2FA: true,
TempToken: tempToken,
UserEmailMasked: service.MaskEmail(user.Email),
})
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
}
// TotpLoginResponse represents the response when 2FA is required
type TotpLoginResponse struct {
Requires2FA bool `json:"requires_2fa"`
TempToken string `json:"temp_token,omitempty"`
UserEmailMasked string `json:"user_email_masked,omitempty"`
}
// Login2FARequest represents the 2FA login request
type Login2FARequest struct {
TempToken string `json:"temp_token" binding:"required"`
TotpCode string `json:"totp_code" binding:"required,len=6"`
}
// Login2FA completes the login with 2FA verification
// POST /api/v1/auth/login/2fa
func (h *AuthHandler) Login2FA(c *gin.Context) {
var req Login2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
slog.Debug("login_2fa_request",
"temp_token_len", len(req.TempToken),
"totp_code_len", len(req.TotpCode))
// Get the login session
session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
if err != nil || session == nil {
tokenPrefix := ""
if len(req.TempToken) >= 8 {
tokenPrefix = req.TempToken[:8]
}
slog.Debug("login_2fa_session_invalid",
"temp_token_prefix", tokenPrefix,
"error", err)
response.BadRequest(c, "Invalid or expired 2FA session")
return
}
slog.Debug("login_2fa_session_found",
"user_id", session.UserID,
"email", session.Email)
// Verify the TOTP code
if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
slog.Debug("login_2fa_verify_failed",
"user_id", session.UserID,
"error", err)
response.ErrorFrom(c, err)
return
}
// Delete the login session
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
// Get the user
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Generate the JWT token
token, err := h.authService.GenerateToken(user)
if err != nil {
response.InternalError(c, "Failed to generate token")
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
@@ -247,3 +345,85 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
BonusAmount: promoCode.BonusAmount,
})
}
// ForgotPasswordRequest 忘记密码请求
type ForgotPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
TurnstileToken string `json:"turnstile_token"`
}
// ForgotPasswordResponse 忘记密码响应
type ForgotPasswordResponse struct {
Message string `json:"message"`
}
// ForgotPassword 请求密码重置
// POST /api/v1/auth/forgot-password
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
var req ForgotPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
// Build frontend base URL from request
scheme := "https"
if c.Request.TLS == nil {
// Check X-Forwarded-Proto header (common in reverse proxy setups)
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
scheme = proto
} else {
scheme = "http"
}
}
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ForgotPasswordResponse{
Message: "If your email is registered, you will receive a password reset link shortly.",
})
}
// ResetPasswordRequest 重置密码请求
type ResetPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
Token string `json:"token" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
// ResetPasswordResponse 重置密码响应
type ResetPasswordResponse struct {
Message string `json:"message"`
}
// ResetPassword 重置密码
// POST /api/v1/auth/reset-password
func (h *AuthHandler) ResetPassword(c *gin.Context) {
var req ResetPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Reset password
if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ResetPasswordResponse{
Message: "Your password has been reset successfully. You can now log in with your new password.",
})
}

View File

@@ -2,9 +2,12 @@ package dto
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@@ -54,21 +57,23 @@ type SystemSettings struct {
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
}
// StreamTimeoutSettings 流超时处理配置 DTO

View File

@@ -209,17 +209,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
// 检查请求拦截(预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
@@ -344,17 +347,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
// 检查请求拦截(预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
@@ -765,17 +771,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
}
// isWarmupRequest 检测是否为预热请求标题生成、Warmup等
func isWarmupRequest(body []byte) bool {
// 快速检查如果body不包含关键字直接返回false
// InterceptType 表示请求拦截类型
type InterceptType int
const (
InterceptTypeNone InterceptType = iota
InterceptTypeWarmup // 预热请求(返回 "New Conversation"
InterceptTypeSuggestionMode // SUGGESTION MODE返回空字符串
)
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
func detectInterceptType(body []byte) InterceptType {
// 快速检查:如果不包含任何关键字,直接返回
bodyStr := string(body)
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
return false
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
if !hasSuggestionMode && !hasWarmupKeyword {
return InterceptTypeNone
}
// 解析完整请求
// 解析请求(只解析一次)
var req struct {
Messages []struct {
Role string `json:"role"`
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
@@ -786,43 +805,71 @@ func isWarmupRequest(body []byte) bool {
} `json:"system"`
}
if err := json.Unmarshal(body, &req); err != nil {
return false
return InterceptTypeNone
}
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return true
// 检查 SUGGESTION MODE最后一条 user 消息)
if hasSuggestionMode && len(req.Messages) > 0 {
lastMsg := req.Messages[len(req.Messages)-1]
if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
lastMsg.Content[0].Type == "text" &&
strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
return InterceptTypeSuggestionMode
}
}
// 检查 Warmup 请求
if hasWarmupKeyword {
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return InterceptTypeWarmup
}
}
}
}
// 检查 system 中的标题提取模式
for _, sys := range req.System {
if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return InterceptTypeWarmup
}
}
}
// 检查 system 中的标题提取模式
for _, system := range req.System {
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return true
}
}
return false
return InterceptTypeNone
}
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
func sendMockWarmupStream(c *gin.Context, model string) {
// sendMockInterceptStream 发送流式 mock 响应(用于请求拦截)
func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
// 根据拦截类型决定响应内容
var msgID string
var outputTokens int
var textDeltas []string
switch interceptType {
case InterceptTypeSuggestionMode:
msgID = "msg_mock_suggestion"
outputTokens = 1
textDeltas = []string{""} // 空内容
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
outputTokens = 2
textDeltas = []string{"New", " Conversation"}
}
// Build message_start event with proper JSON marshaling
messageStart := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": "msg_mock_warmup",
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
@@ -837,16 +884,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
}
messageStartJSON, _ := json.Marshal(messageStart)
// Build events
events := []string{
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
}
// Add text deltas
for _, text := range textDeltas {
delta := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]string{
"type": "text_delta",
"text": text,
},
}
deltaJSON, _ := json.Marshal(delta)
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
}
// Add final events
messageDelta := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": "end_turn",
"stop_sequence": nil,
},
"usage": map[string]int{
"input_tokens": 10,
"output_tokens": outputTokens,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
events = append(events,
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
`event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
)
for _, event := range events {
_, _ = c.Writer.WriteString(event + "\n\n")
c.Writer.Flush()
@@ -854,18 +931,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
}
}
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) {
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
var msgID, text string
var outputTokens int
switch interceptType {
case InterceptTypeSuggestionMode:
msgID = "msg_mock_suggestion"
text = ""
outputTokens = 1
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
text = "New Conversation"
outputTokens = 2
}
c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup",
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
"content": []gin.H{{"type": "text", "text": text}},
"stop_reason": "end_turn",
"usage": gin.H{
"input_tokens": 10,
"output_tokens": 2,
"output_tokens": outputTokens,
},
})
}

View File

@@ -0,0 +1,122 @@
//go:build unit
package handler
import (
"crypto/sha256"
"encoding/hex"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractGeminiCLISessionHash(t *testing.T) {
tests := []struct {
name string
body string
privilegedUserID string
wantEmpty bool
wantHash string
}{
{
name: "with privileged-user-id and tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: false,
wantHash: func() string {
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}(),
},
{
name: "without privileged-user-id but with tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "",
wantEmpty: false,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "without tmp dir",
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
{
name: "empty body",
body: "",
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 创建测试上下文
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/test", nil)
if tt.privilegedUserID != "" {
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
}
// 调用函数
result := extractGeminiCLISessionHash(c, []byte(tt.body))
// 验证结果
if tt.wantEmpty {
require.Empty(t, result, "expected empty session hash")
} else {
require.NotEmpty(t, result, "expected non-empty session hash")
require.Equal(t, tt.wantHash, result, "session hash mismatch")
}
})
}
}
func TestGeminiCLITmpDirRegex(t *testing.T) {
tests := []struct {
name string
input string
wantMatch bool
wantHash string
}{
{
name: "valid tmp dir path",
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "valid tmp dir path in text",
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "invalid hash length",
input: "/Users/ianshaw/.gemini/tmp/abc123",
wantMatch: false,
},
{
name: "no tmp dir",
input: "Hello world",
wantMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
if tt.wantMatch {
require.NotNil(t, match, "expected regex to match")
require.Len(t, match, 2, "expected 2 capture groups")
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
} else {
require.Nil(t, match, "expected regex not to match")
}
})
}
}

View File

@@ -1,11 +1,15 @@
package handler
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"io"
"log"
"net/http"
"regexp"
"strings"
"time"
@@ -19,6 +23,17 @@ import (
"github.com/gin-gonic/gin"
)
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
return true
}
return geminiCLITmpDirRegex.Match(body)
}
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 优先使用 Gemini CLI 的会话标识privileged-user-id + tmp 目录哈希)
sessionHash := extractGeminiCLISessionHash(c, body)
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
// 查询粘性会话绑定的账号 ID用于检测账号切换
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后CLI 继续携带旧签名。
// 为避免第一次转发就 400这里做一次确定性清理让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
} else if sessionBoundAccountID == 0 {
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
sessionBoundAccountID = account.ID
}
// 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
}
return false
}
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
//
// 会话标识生成策略:
// 1. 从请求体中提取 tmp 目录哈希64位十六进制
// 2. 从 header 中提取 privileged-user-idUUID
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
//
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
//
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
// 1. 从请求体中提取 tmp 目录哈希
match := geminiCLITmpDirRegex.FindSubmatch(body)
if len(match) < 2 {
return "" // 没有找到 tmp 目录,不使用粘性会话
}
tmpDirHash := string(match[1])
// 2. 提取 privileged-user-id
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
// 3. 组合生成最终的 session hash
if privilegedUserID != "" {
// 组合两个标识符privileged-user-id + tmp 目录哈希
combined := privilegedUserID + ":" + tmpDirHash
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}
// 如果没有 privileged-user-id直接使用 tmp 目录哈希
return tmpDirHash
}

View File

@@ -37,6 +37,7 @@ type Handlers struct {
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler
Totp *TotpHandler
}
// BuildInfo contains build-time information

View File

@@ -32,20 +32,21 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
}
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
})
}

View File

@@ -0,0 +1,181 @@
package handler
import (
"github.com/gin-gonic/gin"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TotpHandler handles TOTP-related requests
type TotpHandler struct {
totpService *service.TotpService
}
// NewTotpHandler creates a new TotpHandler
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
return &TotpHandler{
totpService: totpService,
}
}
// TotpStatusResponse represents the TOTP status response
type TotpStatusResponse struct {
Enabled bool `json:"enabled"`
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
FeatureEnabled bool `json:"feature_enabled"`
}
// GetStatus returns the TOTP status for the current user
// GET /api/v1/user/totp/status
func (h *TotpHandler) GetStatus(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := TotpStatusResponse{
Enabled: status.Enabled,
FeatureEnabled: status.FeatureEnabled,
}
if status.EnabledAt != nil {
ts := status.EnabledAt.Unix()
resp.EnabledAt = &ts
}
response.Success(c, resp)
}
// TotpSetupRequest represents the request to initiate TOTP setup
type TotpSetupRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// TotpSetupResponse represents the TOTP setup response
type TotpSetupResponse struct {
Secret string `json:"secret"`
QRCodeURL string `json:"qr_code_url"`
SetupToken string `json:"setup_token"`
Countdown int `json:"countdown"`
}
// InitiateSetup starts the TOTP setup process
// POST /api/v1/user/totp/setup
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpSetupRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body (optional params)
req = TotpSetupRequest{}
}
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, TotpSetupResponse{
Secret: result.Secret,
QRCodeURL: result.QRCodeURL,
SetupToken: result.SetupToken,
Countdown: result.Countdown,
})
}
// TotpEnableRequest represents the request to enable TOTP
type TotpEnableRequest struct {
TotpCode string `json:"totp_code" binding:"required,len=6"`
SetupToken string `json:"setup_token" binding:"required"`
}
// Enable completes the TOTP setup
// POST /api/v1/user/totp/enable
func (h *TotpHandler) Enable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpEnableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// TotpDisableRequest represents the request to disable TOTP
type TotpDisableRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// Disable disables TOTP for the current user
// POST /api/v1/user/totp/disable
func (h *TotpHandler) Disable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpDisableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// GetVerificationMethod returns the verification method for TOTP operations
// GET /api/v1/user/totp/verification-method
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
method := h.totpService.GetVerificationMethod(c.Request.Context())
response.Success(c, method)
}
// SendVerifyCode sends an email verification code for TOTP operations
// POST /api/v1/user/totp/send-code
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}

View File

@@ -70,6 +70,7 @@ func ProvideHandlers(
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
settingHandler *SettingHandler,
totpHandler *TotpHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
@@ -82,6 +83,7 @@ func ProvideHandlers(
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler,
Totp: totpHandler,
}
}
@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
NewTotpHandler,
ProvideSettingHandler,
// Admin handlers

View File

@@ -7,13 +7,11 @@ import (
"fmt"
"log"
"math/rand"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -369,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
Text: block.Thinking,
Thought: true,
}
// 保留原有 signatureClaude 模型需要有效的 signature
if block.Signature != "" {
// signature 处理:
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失)
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature在缺失时降级为普通文本并在上层禁用 thinking mode。
@@ -409,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
},
}
// tool_use 的 signature 处理:
// - Gemini 模型:使用 dummy signature跳过 thought_signature 校验
// - Claude 模型:透传上游返回的真实 signatureVertex/Google 需要完整签名链路)
if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
@@ -594,11 +594,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
// 清理 JSON Schema
params := cleanJSONSchema(inputSchema)
// 1. 深度清理 [undefined] 值
DeepCleanUndefined(inputSchema)
// 2. 转换为符合 Gemini v1internal 的 schema
params := CleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值
if params == nil {
params = map[string]any{
"type": "OBJECT",
"type": "object", // lowercase type
"properties": map[string]any{},
}
}
@@ -631,236 +634,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
FunctionDeclarations: funcDecls,
}}
}
// cleanJSONSchema 清理 JSON Schema移除 Antigravity/Gemini 不支持的字段
// 参考 proxycast 的实现,确保 schema 符合 JSON Schema draft 2020-12
func cleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
cleaned := cleanSchemaValue(schema, "$")
result, ok := cleaned.(map[string]any)
if !ok {
return nil
}
// 确保有 type 字段(默认 OBJECT
if _, hasType := result["type"]; !hasType {
result["type"] = "OBJECT"
}
// 确保有 properties 字段(默认空对象)
if _, hasProps := result["properties"]; !hasProps {
result["properties"] = make(map[string]any)
}
// 验证 required 中的字段都存在于 properties 中
if required, ok := result["required"].([]any); ok {
if props, ok := result["properties"].(map[string]any); ok {
validRequired := make([]any, 0, len(required))
for _, r := range required {
if reqName, ok := r.(string); ok {
if _, exists := props[reqName]; exists {
validRequired = append(validRequired, r)
}
}
}
if len(validRequired) > 0 {
result["required"] = validRequired
} else {
delete(result, "required")
}
}
}
return result
}
var schemaValidationKeys = map[string]bool{
"minLength": true,
"maxLength": true,
"pattern": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
"uniqueItems": true,
"minItems": true,
"maxItems": true,
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
}
var warnedSchemaKeys sync.Map
func schemaCleaningWarningsEnabled() bool {
// 可通过环境变量强制开关方便排查SUB2API_SCHEMA_CLEAN_WARN=true/false
if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
}
}
// 默认:非 release 模式下输出debug/test
return gin.Mode() != gin.ReleaseMode
}
func warnSchemaKeyRemovedOnce(key, path string) {
if !schemaCleaningWarningsEnabled() {
return
}
if !schemaValidationKeys[key] {
return
}
if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
return
}
log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{
// 元 schema 字段
"$schema": true,
"$id": true,
"$ref": true,
// 字符串验证Gemini 不支持)
"minLength": true,
"maxLength": true,
"pattern": true,
// 数字验证Claude API 通过 Vertex AI 不支持这些字段)
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
// 数组验证Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems": true,
"minItems": true,
"maxItems": true,
// 组合 schemaGemini 不支持)
"oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// Claude 特有字段
"strict": true,
}
// cleanSchemaValue 递归清理 schema 值
func cleanSchemaValue(value any, path string) any {
switch v := value.(type) {
case map[string]any:
result := make(map[string]any)
for k, val := range v {
// 跳过不支持的字段
if excludedSchemaKeys[k] {
warnSchemaKeyRemovedOnce(k, path)
continue
}
// 特殊处理 type 字段
if k == "type" {
result[k] = cleanTypeValue(val)
continue
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalPropertiesClaude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
} else {
// 如果是 schema 对象,转换为 false更安全的默认值
result[k] = false
}
continue
}
// 递归清理所有值
result[k] = cleanSchemaValue(val, path+"."+k)
}
return result
case []any:
// 递归处理数组中的每个元素
cleaned := make([]any, 0, len(v))
for i, item := range v {
cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
}
return cleaned
default:
return value
}
}
// cleanTypeValue 处理 type 字段,转换为大写
func cleanTypeValue(value any) any {
switch v := value.(type) {
case string:
return strings.ToUpper(v)
case []any:
// 联合类型 ["string", "null"] -> 取第一个非 null 类型
for _, t := range v {
if ts, ok := t.(string); ok && ts != "null" {
return strings.ToUpper(ts)
}
}
// 如果只有 null返回 STRING
return "STRING"
default:
return value
}
}

View File

@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]`
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil {
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != "sig_tool_abc" {
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
}
})
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
contentNoSig := `[
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
]`
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
}

View File

@@ -3,6 +3,7 @@ package antigravity
import (
"encoding/json"
"fmt"
"log"
"strings"
)
@@ -19,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
} else if len(v1Resp.Response.Candidates) == 0 {
// 第一次解析成功但 candidates 为空,说明是直接的 GeminiResponse 格式
var directResp GeminiResponse
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
return nil, nil, fmt.Errorf("parse gemini response as direct: %w", err2)
}
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
}
// 使用处理器转换
@@ -173,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
p.trailingSignature = ""
}
p.textBuilder += part.Text
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
// 非空 text 带签名 - 特殊处理:先输出 text再输出空 thinking 块
if signature != "" {
p.flushText()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "text",
Text: part.Text,
})
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: signature,
})
} else {
// 普通 text (无签名) - 累积到 builder
p.textBuilder += part.Text
}
}
}
@@ -242,6 +256,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
var finishReason string
if len(geminiResp.Candidates) > 0 {
finishReason = geminiResp.Candidates[0].FinishReason
if finishReason == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in response for model %s", originalModel)
if geminiResp.Candidates[0].Content != nil {
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
stopReason := "end_turn"

View File

@@ -0,0 +1,519 @@
package antigravity
import (
"fmt"
"strings"
)
// CleanJSONSchema 清理 JSON Schema移除 Antigravity/Gemini 不支持的字段
// 参考 Antigravity-Manager/src-tauri/src/proxy/common/json_schema.rs 实现
// 确保 schema 符合 JSON Schema draft 2020-12 且适配 Gemini v1internal
func CleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
// 0. 预处理:展开 $ref (Schema Flattening)
// (Go map 是引用的,直接修改 schema)
flattenRefs(schema, extractDefs(schema))
// 递归清理
cleaned := cleanJSONSchemaRecursive(schema)
result, ok := cleaned.(map[string]any)
if !ok {
return nil
}
return result
}
// extractDefs 提取并移除定义的 helper
func extractDefs(schema map[string]any) map[string]any {
defs := make(map[string]any)
if d, ok := schema["$defs"].(map[string]any); ok {
for k, v := range d {
defs[k] = v
}
delete(schema, "$defs")
}
if d, ok := schema["definitions"].(map[string]any); ok {
for k, v := range d {
defs[k] = v
}
delete(schema, "definitions")
}
return defs
}
// flattenRefs 递归展开 $ref
func flattenRefs(schema map[string]any, defs map[string]any) {
if len(defs) == 0 {
return // 无需展开
}
// 检查并替换 $ref
if ref, ok := schema["$ref"].(string); ok {
delete(schema, "$ref")
// 解析引用名 (例如 #/$defs/MyType -> MyType)
parts := strings.Split(ref, "/")
refName := parts[len(parts)-1]
if defSchema, exists := defs[refName]; exists {
if defMap, ok := defSchema.(map[string]any); ok {
// 合并定义内容 (不覆盖现有 key)
for k, v := range defMap {
if _, has := schema[k]; !has {
schema[k] = deepCopy(v) // 需深拷贝避免共享引用
}
}
// 递归处理刚刚合并进来的内容
flattenRefs(schema, defs)
}
}
}
// 遍历子节点
for _, v := range schema {
if subMap, ok := v.(map[string]any); ok {
flattenRefs(subMap, defs)
} else if subArr, ok := v.([]any); ok {
for _, item := range subArr {
if itemMap, ok := item.(map[string]any); ok {
flattenRefs(itemMap, defs)
}
}
}
}
}
// deepCopy 深拷贝 (简单实现,仅针对 JSON 类型)
func deepCopy(src any) any {
if src == nil {
return nil
}
switch v := src.(type) {
case map[string]any:
dst := make(map[string]any)
for k, val := range v {
dst[k] = deepCopy(val)
}
return dst
case []any:
dst := make([]any, len(v))
for i, val := range v {
dst[i] = deepCopy(val)
}
return dst
default:
return src
}
}
// cleanJSONSchemaRecursive 递归核心清理逻辑
// 返回处理后的值 (通常是 input map但可能修改内部结构)
func cleanJSONSchemaRecursive(value any) any {
schemaMap, ok := value.(map[string]any)
if !ok {
return value
}
// 0. [NEW] 合并 allOf
mergeAllOf(schemaMap)
// 1. [CRITICAL] 深度递归处理子项
if props, ok := schemaMap["properties"].(map[string]any); ok {
for _, v := range props {
cleanJSONSchemaRecursive(v)
}
// Go 中不需要像 Rust 那样显式处理 nullable_keys remove required
// 因为我们在子项处理中会正确设置 type 和 description
} else if items, ok := schemaMap["items"]; ok {
// [FIX] Gemini 期望 "items" 是单个 Schema 对象(列表验证),而不是数组(元组验证)。
if itemsArr, ok := items.([]any); ok {
// 策略:将元组 [A, B] 视为 A、B 中的最佳匹配项。
best := extractBestSchemaFromUnion(itemsArr)
if best == nil {
// 回退到通用字符串
best = map[string]any{"type": "string"}
}
// 用处理后的对象替换原有数组
cleanedBest := cleanJSONSchemaRecursive(best)
schemaMap["items"] = cleanedBest
} else {
cleanJSONSchemaRecursive(items)
}
} else {
// 遍历所有值递归
for _, v := range schemaMap {
if _, isMap := v.(map[string]any); isMap {
cleanJSONSchemaRecursive(v)
} else if arr, isArr := v.([]any); isArr {
for _, item := range arr {
cleanJSONSchemaRecursive(item)
}
}
}
}
// 2. [FIX] 处理 anyOf/oneOf 联合类型: 合并属性而非直接删除
var unionArray []any
typeStr, _ := schemaMap["type"].(string)
if typeStr == "" || typeStr == "object" {
if anyOf, ok := schemaMap["anyOf"].([]any); ok {
unionArray = anyOf
} else if oneOf, ok := schemaMap["oneOf"].([]any); ok {
unionArray = oneOf
}
}
if len(unionArray) > 0 {
if bestBranch := extractBestSchemaFromUnion(unionArray); bestBranch != nil {
if bestMap, ok := bestBranch.(map[string]any); ok {
// 合并分支内容
for k, v := range bestMap {
if k == "properties" {
targetProps, _ := schemaMap["properties"].(map[string]any)
if targetProps == nil {
targetProps = make(map[string]any)
schemaMap["properties"] = targetProps
}
if sourceProps, ok := v.(map[string]any); ok {
for pk, pv := range sourceProps {
if _, exists := targetProps[pk]; !exists {
targetProps[pk] = deepCopy(pv)
}
}
}
} else if k == "required" {
targetReq, _ := schemaMap["required"].([]any)
if sourceReq, ok := v.([]any); ok {
for _, rv := range sourceReq {
// 简单的去重添加
exists := false
for _, tr := range targetReq {
if tr == rv {
exists = true
break
}
}
if !exists {
targetReq = append(targetReq, rv)
}
}
schemaMap["required"] = targetReq
}
} else if _, exists := schemaMap[k]; !exists {
schemaMap[k] = deepCopy(v)
}
}
}
}
}
// 3. [SAFETY] 检查当前对象是否为 JSON Schema 节点
looksLikeSchema := hasKey(schemaMap, "type") ||
hasKey(schemaMap, "properties") ||
hasKey(schemaMap, "items") ||
hasKey(schemaMap, "enum") ||
hasKey(schemaMap, "anyOf") ||
hasKey(schemaMap, "oneOf") ||
hasKey(schemaMap, "allOf")
if looksLikeSchema {
// 4. [ROBUST] 约束迁移
migrateConstraints(schemaMap)
// 5. [CRITICAL] 白名单过滤
allowedFields := map[string]bool{
"type": true,
"description": true,
"properties": true,
"required": true,
"items": true,
"enum": true,
"title": true,
}
for k := range schemaMap {
if !allowedFields[k] {
delete(schemaMap, k)
}
}
// 6. [SAFETY] 处理空 Object
if t, _ := schemaMap["type"].(string); t == "object" {
hasProps := false
if props, ok := schemaMap["properties"].(map[string]any); ok && len(props) > 0 {
hasProps = true
}
if !hasProps {
schemaMap["properties"] = map[string]any{
"reason": map[string]any{
"type": "string",
"description": "Reason for calling this tool",
},
}
schemaMap["required"] = []any{"reason"}
}
}
// 7. [SAFETY] Required 字段对齐
if props, ok := schemaMap["properties"].(map[string]any); ok {
if req, ok := schemaMap["required"].([]any); ok {
var validReq []any
for _, r := range req {
if rStr, ok := r.(string); ok {
if _, exists := props[rStr]; exists {
validReq = append(validReq, r)
}
}
}
if len(validReq) > 0 {
schemaMap["required"] = validReq
} else {
delete(schemaMap, "required")
}
}
}
// 8. 处理 type 字段 (Lowercase + Nullable 提取)
isEffectivelyNullable := false
if typeVal, exists := schemaMap["type"]; exists {
var selectedType string
switch v := typeVal.(type) {
case string:
lower := strings.ToLower(v)
if lower == "null" {
isEffectivelyNullable = true
selectedType = "string" // fallback
} else {
selectedType = lower
}
case []any:
// ["string", "null"]
for _, t := range v {
if ts, ok := t.(string); ok {
lower := strings.ToLower(ts)
if lower == "null" {
isEffectivelyNullable = true
} else if selectedType == "" {
selectedType = lower
}
}
}
if selectedType == "" {
selectedType = "string"
}
}
schemaMap["type"] = selectedType
} else {
// 默认 object 如果有 properties (虽然上面白名单过滤可能删了 type 如果它不在... 但 type 必在 allowlist)
// 如果没有 type但有 properties补一个
if hasKey(schemaMap, "properties") {
schemaMap["type"] = "object"
} else {
// 默认为 string ? or object? Gemini 通常需要明确 type
schemaMap["type"] = "object"
}
}
if isEffectivelyNullable {
desc, _ := schemaMap["description"].(string)
if !strings.Contains(desc, "nullable") {
if desc != "" {
desc += " "
}
desc += "(nullable)"
schemaMap["description"] = desc
}
}
// 9. Enum 值强制转字符串
if enumVals, ok := schemaMap["enum"].([]any); ok {
hasNonString := false
for i, val := range enumVals {
if _, isStr := val.(string); !isStr {
hasNonString = true
if val == nil {
enumVals[i] = "null"
} else {
enumVals[i] = fmt.Sprintf("%v", val)
}
}
}
// If we mandated string values, we must ensure type is string
if hasNonString {
schemaMap["type"] = "string"
}
}
}
return schemaMap
}
func hasKey(m map[string]any, k string) bool {
_, ok := m[k]
return ok
}
func migrateConstraints(m map[string]any) {
constraints := []struct {
key string
label string
}{
{"minLength", "minLen"},
{"maxLength", "maxLen"},
{"pattern", "pattern"},
{"minimum", "min"},
{"maximum", "max"},
{"multipleOf", "multipleOf"},
{"exclusiveMinimum", "exclMin"},
{"exclusiveMaximum", "exclMax"},
{"minItems", "minItems"},
{"maxItems", "maxItems"},
{"propertyNames", "propertyNames"},
{"format", "format"},
}
var hints []string
for _, c := range constraints {
if val, ok := m[c.key]; ok && val != nil {
hints = append(hints, fmt.Sprintf("%s: %v", c.label, val))
}
}
if len(hints) > 0 {
suffix := fmt.Sprintf(" [Constraint: %s]", strings.Join(hints, ", "))
desc, _ := m["description"].(string)
if !strings.Contains(desc, suffix) {
m["description"] = desc + suffix
}
}
}
// mergeAllOf 合并 allOf
func mergeAllOf(m map[string]any) {
allOf, ok := m["allOf"].([]any)
if !ok {
return
}
delete(m, "allOf")
mergedProps := make(map[string]any)
mergedReq := make(map[string]bool)
otherFields := make(map[string]any)
for _, sub := range allOf {
if subMap, ok := sub.(map[string]any); ok {
// Props
if props, ok := subMap["properties"].(map[string]any); ok {
for k, v := range props {
mergedProps[k] = v
}
}
// Required
if reqs, ok := subMap["required"].([]any); ok {
for _, r := range reqs {
if s, ok := r.(string); ok {
mergedReq[s] = true
}
}
}
// Others
for k, v := range subMap {
if k != "properties" && k != "required" && k != "allOf" {
if _, exists := otherFields[k]; !exists {
otherFields[k] = v
}
}
}
}
}
// Apply
for k, v := range otherFields {
if _, exists := m[k]; !exists {
m[k] = v
}
}
if len(mergedProps) > 0 {
existProps, _ := m["properties"].(map[string]any)
if existProps == nil {
existProps = make(map[string]any)
m["properties"] = existProps
}
for k, v := range mergedProps {
if _, exists := existProps[k]; !exists {
existProps[k] = v
}
}
}
if len(mergedReq) > 0 {
existReq, _ := m["required"].([]any)
var validReqs []any
for _, r := range existReq {
if s, ok := r.(string); ok {
validReqs = append(validReqs, s)
delete(mergedReq, s) // already exists
}
}
// append new
for r := range mergedReq {
validReqs = append(validReqs, r)
}
m["required"] = validReqs
}
}
// extractBestSchemaFromUnion 从 anyOf/oneOf 中选取最佳分支
func extractBestSchemaFromUnion(unionArray []any) any {
var bestOption any
bestScore := -1
for _, item := range unionArray {
score := scoreSchemaOption(item)
if score > bestScore {
bestScore = score
bestOption = item
}
}
return bestOption
}
func scoreSchemaOption(val any) int {
m, ok := val.(map[string]any)
if !ok {
return 0
}
typeStr, _ := m["type"].(string)
if hasKey(m, "properties") || typeStr == "object" {
return 3
}
if hasKey(m, "items") || typeStr == "array" {
return 2
}
if typeStr != "" && typeStr != "null" {
return 1
}
return 0
}
// DeepCleanUndefined 深度清理值为 "[undefined]" 的字段
func DeepCleanUndefined(value any) {
if value == nil {
return
}
switch v := value.(type) {
case map[string]any:
for k, val := range v {
if s, ok := val.(string); ok && s == "[undefined]" {
delete(v, k)
continue
}
DeepCleanUndefined(val)
}
case []any:
for _, val := range v {
DeepCleanUndefined(val)
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"log"
"strings"
)
@@ -102,6 +103,14 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
// 检查是否结束
if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason
if finishReason == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in stream for model %s", p.originalModel)
if geminiResp.Candidates[0].Content != nil {
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
if finishReason != "" {
_, _ = result.Write(p.emitFinish(finishReason))
}

View File

@@ -24,9 +24,9 @@ const (
RedirectURI = "https://platform.claude.com/oauth/code/callback"
// Scopes - Browser URL (includes org:create_api_key for user authorization)
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code"
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Internal API call (org:create_api_key not supported in API)
ScopeAPI = "user:profile user:inference user:sessions:claude_code"
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Setup token (inference only)
ScopeInference = "user:inference"
@@ -215,5 +215,6 @@ type OrgInfo struct {
// AccountInfo represents account info from OAuth response
type AccountInfo struct {
UUID string `json:"uuid"`
UUID string `json:"uuid"`
EmailAddress string `json:"email_address"`
}

View File

@@ -2,6 +2,7 @@
package response
import (
"log"
"math"
"net/http"
@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool {
}
statusCode, status := infraerrors.ToHTTP(err)
// Log internal errors with full details for debugging
if statusCode >= 500 && c.Request != nil {
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
}
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}

View File

@@ -0,0 +1,278 @@
//go:build integration
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests to external services and should be run manually.
//
// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
)
// skipIfExternalServiceUnavailable checks if the external service is available.
// If not, it skips the test instead of failing.
func skipIfExternalServiceUnavailable(t *testing.T, err error) {
t.Helper()
if err != nil {
// Check for common network/TLS errors that indicate external service issues
errStr := err.Error()
if strings.Contains(errStr, "certificate has expired") ||
strings.Contains(errStr, "certificate is not yet valid") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "no such host") ||
strings.Contains(errStr, "network is unreachable") ||
strings.Contains(errStr, "timeout") {
t.Skipf("skipping test: external service unavailable: %v", err)
}
t.Fatalf("failed to get fingerprint: %v", err)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
skipIfExternalServiceUnavailable(t, err)
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
skipIfExternalServiceUnavailable(t, err)
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}

View File

@@ -1,21 +1,16 @@
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests and should be run manually.
// Unit tests for TLS fingerprint dialer.
// Integration tests that require external network are in dialer_integration_test.go
// and require the 'integration' build tag.
//
// Run with: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
)
// FingerprintResponse represents the response from tls.peet.ws/api/all.
@@ -36,148 +31,6 @@ type TLSInfo struct {
SessionID string `json:"session_id"`
}
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
func TestDialerBasicConnection(t *testing.T) {
if testing.Short() {
t.Skip("skipping network test in short mode")
}
// Create a dialer with default profile
profile := &Profile{
Name: "Test Profile",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
// Create HTTP client with custom TLS dialer
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Make a request to a known HTTPS endpoint
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints.
func TestDialerWithProfile(t *testing.T) {
// Create two dialers with different profiles
@@ -305,139 +158,3 @@ func mustParseURL(rawURL string) *url.URL {
}
return u
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
return nil
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}

View File

@@ -0,0 +1,95 @@
package repository
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// AESEncryptor implements SecretEncryptor using AES-256-GCM
type AESEncryptor struct {
key []byte
}
// NewAESEncryptor creates a new AES encryptor
func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("invalid totp encryption key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
}
return &AESEncryptor{key: key}, nil
}
// Encrypt encrypts plaintext using AES-256-GCM
// Output format: base64(nonce + ciphertext + tag)
func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create gcm: %w", err)
}
// Generate a random nonce
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("generate nonce: %w", err)
}
// Encrypt the plaintext
// Seal appends the ciphertext and tag to the nonce
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
// Encode as base64
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts ciphertext using AES-256-GCM
func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
// Decode from base64
data, err := base64.StdEncoding.DecodeString(ciphertext)
if err != nil {
return "", fmt.Errorf("decode base64: %w", err)
}
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create gcm: %w", err)
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", fmt.Errorf("ciphertext too short")
}
// Extract nonce and ciphertext
nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
// Decrypt
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
if err != nil {
return "", fmt.Errorf("decrypt: %w", err)
}
return string(plaintext), nil
}

View File

@@ -387,17 +387,20 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}

View File

@@ -35,7 +35,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
client := s.clientFactory(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
UUID string `json:"uuid"`
Name string `json:"name"`
RavenType *string `json:"raven_type"` // nil for personal, "team" for team organization
}
targetURL := s.baseURL + "/api/organizations"
@@ -65,7 +67,23 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return "", fmt.Errorf("no organizations found")
}
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
// 如果只有一个组织,直接使用
if len(orgs) == 1 {
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil
}
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
for _, org := range orgs {
if org.RavenType != nil && *org.RavenType == "team" {
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
org.UUID, org.Name, *org.RavenType)
return org.UUID, nil
}
}
// 如果没有 team 类型的组织,使用第一个
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil
}

View File

@@ -9,13 +9,27 @@ import (
"github.com/redis/go-redis/v9"
)
const verifyCodeKeyPrefix = "verify_code:"
const (
verifyCodeKeyPrefix = "verify_code:"
passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:"
)
// verifyCodeKey generates the Redis key for email verification code.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
}
// passwordResetKey generates the Redis key for password reset token.
func passwordResetKey(email string) string {
return passwordResetKeyPrefix + email
}
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
func passwordResetSentAtKey(email string) string {
return passwordResetSentAtKeyPrefix + email
}
type emailCache struct {
rdb *redis.Client
}
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
key := verifyCodeKey(email)
return c.rdb.Del(ctx, key).Err()
}
// Password reset token methods
func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
key := passwordResetKey(email)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data service.PasswordResetTokenData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
key := passwordResetKey(email)
val, err := json.Marshal(data)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
key := passwordResetKey(email)
return c.rdb.Del(ctx, key).Err()
}
// Password reset email cooldown methods
func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
key := passwordResetSentAtKey(email)
exists, err := c.rdb.Exists(ctx, key).Result()
return err == nil && exists > 0
}
func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
key := passwordResetSentAtKey(email)
return c.rdb.Set(ctx, key, "1", ttl).Err()
}

View File

@@ -2,11 +2,11 @@ package repository
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
@@ -22,7 +22,7 @@ type openaiOAuthService struct {
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(s.tokenURL, proxyURL)
client := createOpenAIReqClient(proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
@@ -39,23 +39,24 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
resp, err := client.R().
SetContext(ctx).
SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(s.tokenURL, proxyURL)
client := createOpenAIReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -67,29 +68,25 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
resp, err := client.R().
SetContext(ctx).
SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client {
forceHTTP2 := false
if parsedURL, err := url.Parse(tokenURL); err == nil {
forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https")
}
func createOpenAIReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
ForceHTTP2: forceHTTP2,
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
})
}

View File

@@ -77,21 +77,9 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
func TestCreateOpenAIReqClient_ForceHTTP2Enabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
require.Equal(t, "2", forceHTTPVersion(t, client))
}
func TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("http://localhost/oauth/token", "http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
client := createOpenAIReqClient("http://proxy.local:8080")
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}

View File

@@ -0,0 +1,149 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const (
totpSetupKeyPrefix = "totp:setup:"
totpLoginKeyPrefix = "totp:login:"
totpAttemptsKeyPrefix = "totp:attempts:"
totpAttemptsTTL = 15 * time.Minute
)
// TotpCache implements service.TotpCache using Redis
type TotpCache struct {
rdb *redis.Client
}
// NewTotpCache creates a new TOTP cache
func NewTotpCache(rdb *redis.Client) service.TotpCache {
return &TotpCache{rdb: rdb}
}
// GetSetupSession retrieves a TOTP setup session
func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
data, err := c.rdb.Get(ctx, key).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("get setup session: %w", err)
}
var session service.TotpSetupSession
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("unmarshal setup session: %w", err)
}
return &session, nil
}
// SetSetupSession stores a TOTP setup session
func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
data, err := json.Marshal(session)
if err != nil {
return fmt.Errorf("marshal setup session: %w", err)
}
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("set setup session: %w", err)
}
return nil
}
// DeleteSetupSession deletes a TOTP setup session
func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
// GetLoginSession retrieves a TOTP login session
func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
key := totpLoginKeyPrefix + tempToken
data, err := c.rdb.Get(ctx, key).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("get login session: %w", err)
}
var session service.TotpLoginSession
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("unmarshal login session: %w", err)
}
return &session, nil
}
// SetLoginSession stores a TOTP login session
func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
key := totpLoginKeyPrefix + tempToken
data, err := json.Marshal(session)
if err != nil {
return fmt.Errorf("marshal login session: %w", err)
}
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("set login session: %w", err)
}
return nil
}
// DeleteLoginSession deletes a TOTP login session
func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
key := totpLoginKeyPrefix + tempToken
return c.rdb.Del(ctx, key).Err()
}
// IncrementVerifyAttempts increments the verify attempt counter
func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
// Use pipeline for atomic increment and set TTL
pipe := c.rdb.Pipeline()
incrCmd := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, totpAttemptsTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("increment verify attempts: %w", err)
}
count, err := incrCmd.Result()
if err != nil {
return 0, fmt.Errorf("get increment result: %w", err)
}
return int(count), nil
}
// GetVerifyAttempts gets the current verify attempt count
func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
count, err := c.rdb.Get(ctx, key).Int()
if err != nil {
if err == redis.Nil {
return 0, nil
}
return 0, fmt.Errorf("get verify attempts: %w", err)
}
return count, nil
}
// ClearVerifyAttempts clears the verify attempt counter
func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"sort"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
@@ -466,3 +467,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
client := clientFromContext(ctx, r.client)
update := client.User.UpdateOneID(userID)
if encryptedSecret == nil {
update = update.ClearTotpSecretEncrypted()
} else {
update = update.SetTotpSecretEncrypted(*encryptedSecret)
}
_, err := update.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
// EnableTotp 启用用户的 TOTP 双因素认证
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.User.UpdateOneID(userID).
SetTotpEnabled(true).
SetTotpEnabledAt(time.Now()).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
// DisableTotp 禁用用户的 TOTP 双因素认证
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.User.UpdateOneID(userID).
SetTotpEnabled(false).
ClearTotpEnabledAt().
ClearTotpSecretEncrypted().
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}

View File

@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID))
}
if status != "" {
// Status filtering with real-time expiration check
now := time.Now()
switch status {
case service.SubscriptionStatusActive:
// Active: status is active AND not yet expired
q = q.Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtGT(now),
)
case service.SubscriptionStatusExpired:
// Expired: status is expired OR (status is active but already expired)
q = q.Where(
usersubscription.Or(
usersubscription.StatusEQ(service.SubscriptionStatusExpired),
usersubscription.And(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(now),
),
),
)
case "":
// No filter
default:
// Other status (e.g., revoked)
q = q.Where(usersubscription.StatusEQ(status))
}
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return nil, nil, err
}
// Apply sorting
q = q.WithUser().WithGroup().WithAssignedByUser()
// Determine sort field
var field string
switch sortBy {
case "expires_at":
field = usersubscription.FieldExpiresAt
case "status":
field = usersubscription.FieldStatus
default:
field = usersubscription.FieldCreatedAt
}
// Determine sort order (default: desc)
if sortOrder == "asc" && sortBy != "" {
q = q.Order(dbent.Asc(field))
} else {
q = q.Order(dbent.Desc(field))
}
subs, err := q.
WithUser().
WithGroup().
WithAssignedByUser().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
Offset(params.Offset()).
Limit(params.Limit()).
All(ctx)

View File

@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)

View File

@@ -82,6 +82,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache,
NewSchedulerOutboxRepository,
NewProxyLatencyCache,
NewTotpCache,
// Encryptors
NewAESEncryptor,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,

View File

@@ -193,20 +193,20 @@ func TestAPIContracts(t *testing.T) {
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
deps.userSubRepo.SetByUserID(1, []service.UserSubscription{
{
ID: 501,
UserID: 1,
GroupID: 10,
StartsAt: deps.now,
ExpiresAt: deps.now.Add(24 * time.Hour),
Status: service.SubscriptionStatusActive,
ID: 501,
UserID: 1,
GroupID: 10,
StartsAt: deps.now,
ExpiresAt: time.Date(2099, 1, 2, 3, 4, 5, 0, time.UTC), // 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
Status: service.SubscriptionStatusActive,
DailyUsageUSD: 1.23,
WeeklyUsageUSD: 2.34,
MonthlyUsageUSD: 3.45,
AssignedBy: ptr(int64(999)),
AssignedAt: deps.now,
Notes: "admin-note",
CreatedAt: deps.now,
UpdatedAt: deps.now,
AssignedBy: ptr(int64(999)),
AssignedAt: deps.now,
Notes: "admin-note",
CreatedAt: deps.now,
UpdatedAt: deps.now,
},
})
},
@@ -222,7 +222,7 @@ func TestAPIContracts(t *testing.T) {
"user_id": 1,
"group_id": 10,
"starts_at": "2025-01-02T03:04:05Z",
"expires_at": "2025-01-03T03:04:05Z",
"expires_at": "2099-01-02T03:04:05Z",
"status": "active",
"daily_window_start": null,
"weekly_window_start": null,
@@ -452,6 +452,9 @@ func TestAPIContracts(t *testing.T) {
"registration_enabled": true,
"email_verify_enabled": false,
"promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
@@ -595,7 +598,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
@@ -754,6 +757,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented")
}
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
return errors.New("not implemented")
}
type stubApiKeyCache struct{}
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
@@ -1176,7 +1191,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {

View File

@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return nil, nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}

View File

@@ -26,11 +26,20 @@ func RegisterAuthRoutes(
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidatePromoCode)
// 忘记密码接口添加速率限制:每分钟最多 5 次Redis 故障时 fail-close
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ForgotPassword)
// 重置密码接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
}

View File

@@ -22,6 +22,17 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
// TOTP 双因素认证
totp := user.Group("/totp")
{
totp.GET("/status", h.Totp.GetStatus)
totp.GET("/verification-method", h.Totp.GetVerificationMethod)
totp.POST("/send-code", h.Totp.SendVerifyCode)
totp.POST("/setup", h.Totp.InitiateSetup)
totp.POST("/enable", h.Totp.Enable)
totp.POST("/disable", h.Totp.Disable)
}
}
// API Key管理

View File

@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil
}
// GetCredentialAsInt64 解析凭证中的 int64 字段
// 用于读取 _token_version 等内部字段
func (a *Account) GetCredentialAsInt64(key string) int64 {
if a == nil || a.Credentials == nil {
return 0
}
val, ok := a.Credentials[key]
if !ok || val == nil {
return 0
}
switch v := val.(type) {
case int64:
return v
case float64:
return int64(v)
case int:
return int64(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return i
}
case string:
if i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
return i
}
}
return 0
}
func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil {
return false

View File

@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
}
type groupRepoStub struct {
affectedUserIDs []int64
deleteErr error

View File

@@ -1305,6 +1305,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, err
}
// 清理 Schema
if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
injectedBody = cleanedBody
log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
} else {
log.Printf("[Antigravity] Failed to clean schema: %v", err)
}
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
if err != nil {
@@ -1705,6 +1713,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
// Check for MALFORMED_FUNCTION_CALL
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
if content, ok := cand["content"]; ok {
if b, err := json.Marshal(content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
}
}
if firstTokenMs == nil {
@@ -1854,6 +1875,20 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
usage = u
}
// Check for MALFORMED_FUNCTION_CALL
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
if content, ok := cand["content"]; ok {
if b, err := json.Marshal(content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
}
// 保留最后一个有 parts 的响应
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
@@ -1950,6 +1985,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi
return result, existingParts, setParts
}
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
// 这个函数会合并所有类型的 partstext、thinking、functionCall、inlineData 等
// 保持原始顺序,只合并连续的普通 text parts
func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
if len(collectedParts) == 0 {
return response
}
result, _, setParts := getOrCreateGeminiParts(response)
// 合并策略:
// 1. 保持原始顺序
// 2. 连续的普通 text parts 合并为一个
// 3. thinking、functionCall、inlineData 等保持原样
var mergedParts []any
var textBuffer strings.Builder
flushTextBuffer := func() {
if textBuffer.Len() > 0 {
mergedParts = append(mergedParts, map[string]any{
"text": textBuffer.String(),
})
textBuffer.Reset()
}
}
for _, part := range collectedParts {
// 检查是否是普通 text part
if text, ok := part["text"].(string); ok {
// 检查是否有 thought 标记
if thought, _ := part["thought"].(bool); thought {
// thinking part先刷新 text buffer然后保留原样
flushTextBuffer()
mergedParts = append(mergedParts, part)
} else {
// 普通 text累积到 buffer
_, _ = textBuffer.WriteString(text)
}
} else {
// 非 text partfunctionCall、inlineData 等),先刷新 text buffer然后保留原样
flushTextBuffer()
mergedParts = append(mergedParts, part)
}
}
// 刷新剩余的 text
flushTextBuffer()
setParts(mergedParts)
return result
}
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
if len(imageParts) == 0 {
@@ -2133,6 +2220,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var firstTokenMs *int
var last map[string]any
var lastWithParts map[string]any
var collectedParts []map[string]any // 收集所有 parts包括 text、thinking、functionCall、inlineData 等)
type scanEvent struct {
line string
@@ -2227,9 +2315,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
last = parsed
// 保留最后一个有 parts 的响应
// 保留最后一个有 parts 的响应,并收集所有 parts
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
// 收集所有 partstext、thinking、functionCall、inlineData 等)
collectedParts = append(collectedParts, parts...)
}
case <-intervalCh:
@@ -2252,6 +2343,11 @@ returnResponse:
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
}
// 将收集的所有 parts 合并到最终响应中
if len(collectedParts) > 0 {
finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
}
// 序列化为 JSONGemini 格式)
geminiBody, err := json.Marshal(finalResponse)
if err != nil {
@@ -2459,3 +2555,55 @@ func isImageGenerationModel(model string) bool {
modelLower == "gemini-2.5-flash-image-preview" ||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
}
// cleanGeminiRequest 清理 Gemini 请求体中的 Schema
func cleanGeminiRequest(body []byte) ([]byte, error) {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return nil, err
}
modified := false
// 1. 清理 Tools
if tools, ok := payload["tools"].([]any); ok && len(tools) > 0 {
for _, t := range tools {
toolMap, ok := t.(map[string]any)
if !ok {
continue
}
// function_declarations (snake_case) or functionDeclarations (camelCase)
var funcs []any
if f, ok := toolMap["functionDeclarations"].([]any); ok {
funcs = f
} else if f, ok := toolMap["function_declarations"].([]any); ok {
funcs = f
}
if len(funcs) == 0 {
continue
}
for _, f := range funcs {
funcMap, ok := f.(map[string]any)
if !ok {
continue
}
if params, ok := funcMap["parameters"].(map[string]any); ok {
antigravity.DeepCleanUndefined(params)
cleaned := antigravity.CleanJSONSchema(params)
funcMap["parameters"] = cleaned
modified = true
}
}
}
}
if !modified {
return body, nil
}
return json.Marshal(payload)
}

View File

@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.Email = userInfo.Email
}
// 获取 project_id部分账户类型可能没有
loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" {
result.ProjectID = loadResp.CloudAICompanionProject
// 获取 project_id部分账户类型可能没有,失败时重试
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
result.ProjectIDMissing = true
} else {
result.ProjectID = projectID
}
return result, nil
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
tokenInfo.Email = existingEmail
}
// 每次刷新都调用 LoadCodeAssist 获取 project_id
client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
// LoadCodeAssist 失败或返回空,保留原有 project_id标记缺失
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
// 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if loadErr != nil {
// LoadCodeAssist 失败,保留原有 project_id
tokenInfo.ProjectID = existingProjectID
tokenInfo.ProjectIDMissing = true
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
// 如果之前有 project_id本次只是临时故障不应标记为错误
if existingProjectID == "" {
tokenInfo.ProjectIDMissing = true
}
} else {
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
}
// loadProjectIDWithRetry 带重试机制获取 project_id
// 返回 project_id 和错误,失败时会重试指定次数
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
var lastErr error
for attempt := 0; attempt <= maxRetries; attempt++ {
if attempt > 0 {
// 指数退避1s, 2s, 4s
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
if backoff > 8*time.Second {
backoff = 8 * time.Second
}
time.Sleep(backoff)
}
client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
return loadResp.CloudAICompanionProject, nil
}
// 记录错误
if err != nil {
lastErr = err
} else if loadResp == nil {
lastErr = fmt.Errorf("LoadCodeAssist 返回空响应")
} else {
lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id")
}
}
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
}
// BuildAccountCredentials 构建账户凭证
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
creds := map[string]any{

View File

@@ -94,14 +94,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
var handleErrorCalled bool
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
prefix: "[test]",
ctx: context.Background(),
account: account,
proxyURL: "",
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
prefix: "[test]",
ctx: context.Background(),
account: account,
proxyURL: "",
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
handleErrorCalled = true

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
"time"
@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > antigravityTokenCacheSkew:
ttl = until - antigravityTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > antigravityTokenCacheSkew:
ttl = until - antigravityTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil

View File

@@ -3,6 +3,8 @@ package service
import (
"context"
"fmt"
"log"
"strings"
"time"
)
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
}
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
// 合并旧的 credentials保留新 credentials 中不存在的字段
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败project_id 也不会丢失
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
newCredentials["project_id"] = oldProjectID
}
}
// 如果 project_id 获取失败,只记录警告,不返回错误
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
// Token 刷新本身是成功的access_token 和 refresh_token 已更新)
if tokenInfo.ProjectIDMissing {
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id可能无法使用Antigravity")
if tokenInfo.ProjectID != "" {
// 有旧的 project_id本次获取失败保留旧值
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID)
} else {
// 从未获取过 project_id本次也失败但不返回错误以允许下次重试
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID)
}
}
return newCredentials, nil

View File

@@ -580,3 +580,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 生成新token
return s.GenerateToken(user)
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证且 SMTP 配置正确
func (s *AuthService) IsPasswordResetEnabled(ctx context.Context) bool {
if s.settingService == nil {
return false
}
// Must have email verification enabled and SMTP configured
if !s.settingService.IsEmailVerifyEnabled(ctx) {
return false
}
return s.settingService.IsPasswordResetEnabled(ctx)
}
// preparePasswordReset validates the password reset request and returns necessary data
// Returns (siteName, resetURL, shouldProceed)
// shouldProceed is false when we should silently return success (to prevent enumeration)
func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendBaseURL string) (string, string, bool) {
// Check if user exists (but don't reveal this to the caller)
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// Security: Log but don't reveal that user doesn't exist
log.Printf("[Auth] Password reset requested for non-existent email: %s", email)
return "", "", false
}
log.Printf("[Auth] Database error checking email for password reset: %v", err)
return "", "", false
}
// Check if user is active
if !user.IsActive() {
log.Printf("[Auth] Password reset requested for inactive user: %s", email)
return "", "", false
}
// Get site name
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
// Build reset URL base
resetURL := fmt.Sprintf("%s/reset-password", strings.TrimSuffix(frontendBaseURL, "/"))
return siteName, resetURL, true
}
// RequestPasswordReset 请求密码重置(同步发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error {
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailService == nil {
return ErrServiceUnavailable
}
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
if !shouldProceed {
return nil // Silent success to prevent enumeration
}
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err)
return nil // Silent success to prevent enumeration
}
log.Printf("[Auth] Password reset email sent to: %s", email)
return nil
}
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error {
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailQueueService == nil {
return ErrServiceUnavailable
}
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
if !shouldProceed {
return nil // Silent success to prevent enumeration
}
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err)
return nil // Silent success to prevent enumeration
}
log.Printf("[Auth] Password reset email enqueued for: %s", email)
return nil
}
// ResetPassword 重置密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPassword string) error {
// Check if password reset is enabled
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailService == nil {
return ErrServiceUnavailable
}
// Verify and consume the reset token (one-time use)
if err := s.emailService.ConsumePasswordResetToken(ctx, email, token); err != nil {
return err
}
// Get user
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return ErrInvalidResetToken // Token was valid but user was deleted
}
log.Printf("[Auth] Database error getting user for password reset: %v", err)
return ErrServiceUnavailable
}
// Check if user is active
if !user.IsActive() {
return ErrUserNotActive
}
// Hash new password
hashedPassword, err := s.HashPassword(newPassword)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
// Update password and increment TokenVersion
user.PasswordHash = hashedPassword
user.TokenVersion++ // Invalidate all existing tokens
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err)
return ErrServiceUnavailable
}
log.Printf("[Auth] Password reset successful for user: %s", email)
return nil
}

View File

@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return nil
}
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
return nil, nil
}
func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error {
return nil
}
func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error {
return nil
}
func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
return false
}
func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
return nil
}
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{

View File

@@ -181,26 +181,37 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > claudeTokenCacheSkew:
ttl = until - claudeTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > claudeTokenCacheSkew:
ttl = until - claudeTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}

View File

@@ -69,9 +69,10 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
@@ -87,6 +88,9 @@ const (
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// TOTP 双因素认证设置
SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能
// LinuxDo Connect OAuth 登录设置
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"

View File

@@ -8,11 +8,18 @@ import (
"time"
)
// Task type constants
const (
TaskTypeVerifyCode = "verify_code"
TaskTypePasswordReset = "password_reset"
)
// EmailTask 邮件发送任务
type EmailTask struct {
Email string
SiteName string
TaskType string // "verify_code"
TaskType string // "verify_code" or "password_reset"
ResetURL string // Only used for password_reset task type
}
// EmailQueueService 异步邮件队列服务
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
defer cancel()
switch task.TaskType {
case "verify_code":
case TaskTypeVerifyCode:
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
}
case TaskTypePasswordReset:
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
}
default:
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
}
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: "verify_code",
TaskType: TaskTypeVerifyCode,
}
select {
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
}
}
// EnqueuePasswordReset 将密码重置邮件任务加入队列
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: TaskTypePasswordReset,
ResetURL: resetURL,
}
select {
case s.taskChan <- task:
log.Printf("[EmailQueue] Enqueued password reset task for %s", email)
return nil
default:
return fmt.Errorf("email queue is full")
}
}
// Stop 停止队列服务
func (s *EmailQueueService) Stop() {
close(s.stopChan)

View File

@@ -3,11 +3,14 @@ package service
import (
"context"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"encoding/hex"
"fmt"
"log"
"math/big"
"net/smtp"
"net/url"
"strconv"
"time"
@@ -19,6 +22,9 @@ var (
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
// Password reset errors
ErrInvalidResetToken = infraerrors.BadRequest("INVALID_RESET_TOKEN", "invalid or expired password reset token")
)
// EmailCache defines cache operations for email service
@@ -26,6 +32,16 @@ type EmailCache interface {
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error
// Password reset token methods
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
DeletePasswordResetToken(ctx context.Context, email string) error
// Password reset email cooldown methods
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
}
// VerificationCodeData represents verification code data
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
CreatedAt time.Time
}
// PasswordResetTokenData represents password reset token data
type PasswordResetTokenData struct {
Token string
CreatedAt time.Time
}
const (
verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute
maxVerifyCodeAttempts = 5
// Password reset token settings
passwordResetTokenTTL = 30 * time.Minute
// Password reset email cooldown (prevent email bombing)
passwordResetEmailCooldown = 30 * time.Second
)
// SMTPConfig SMTP配置
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
return ErrVerifyCodeMaxAttempts
}
// 验证码不匹配
if data.Code != code {
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err)
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
return client.Quit()
}
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
func (s *EmailService) GeneratePasswordResetToken() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// SendPasswordResetEmail sends a password reset email with a reset link
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error {
var token string
var needSaveToken bool
// Check if token already exists
existing, err := s.cache.GetPasswordResetToken(ctx, email)
if err == nil && existing != nil {
// Token exists, reuse it (allows resending email without generating new token)
token = existing.Token
needSaveToken = false
} else {
// Generate new token
token, err = s.GeneratePasswordResetToken()
if err != nil {
return fmt.Errorf("generate token: %w", err)
}
needSaveToken = true
}
// Save token to Redis (only if new token generated)
if needSaveToken {
data := &PasswordResetTokenData{
Token: token,
CreatedAt: time.Now(),
}
if err := s.cache.SetPasswordResetToken(ctx, email, data, passwordResetTokenTTL); err != nil {
return fmt.Errorf("save reset token: %w", err)
}
}
// Build full reset URL with URL-encoded token and email
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
// Build email content
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
// Send email
if err := s.SendEmail(ctx, email, subject, body); err != nil {
return fmt.Errorf("send email: %w", err)
}
return nil
}
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
// Check email cooldown to prevent email bombing
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
return nil // Silent success to prevent revealing cooldown to attackers
}
// Send email using core method
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
return err
}
// Set cooldown marker (Redis TTL handles expiration)
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
}
return nil
}
// VerifyPasswordResetToken verifies the password reset token without consuming it
func (s *EmailService) VerifyPasswordResetToken(ctx context.Context, email, token string) error {
data, err := s.cache.GetPasswordResetToken(ctx, email)
if err != nil || data == nil {
return ErrInvalidResetToken
}
// Use constant-time comparison to prevent timing attacks
if subtle.ConstantTimeCompare([]byte(data.Token), []byte(token)) != 1 {
return ErrInvalidResetToken
}
return nil
}
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, token string) error {
// Verify first
if err := s.VerifyPasswordResetToken(ctx, email, token); err != nil {
return err
}
// Delete after verification (one-time use)
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
}
return nil
}
// buildPasswordResetEmailBody builds the HTML content for password reset email
func (s *EmailService) buildPasswordResetEmailBody(resetURL, siteName string) string {
return fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
.button:hover { opacity: 0.9; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
.warning { color: #e74c3c; font-weight: 500; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">密码重置请求</p>
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
<a href="%s" class="button">重置密码</a>
<div class="info">
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
</div>
<div class="link-fallback">
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
<p>%s</p>
</div>
</div>
<div class="footer">
<p>这是一封自动发送的邮件,请勿回复。</p>
</div>
</div>
</body>
</html>
`, siteName, resetURL, resetURL)
}

View File

@@ -305,6 +305,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
}
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
// Returns 0 if no binding exists or on error.
func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
if sessionHash == "" || s.cache == nil {
return 0, nil
}
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err != nil {
return 0, err
}
return accountID, nil
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
@@ -3359,19 +3372,12 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
} `json:"usage"`
}
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
// output_tokens 总是从 message_delta 获取
// message_delta 是推理结束后的最终统计,应完全覆盖 message_start 的数据
// 这对于 Claude API 和 GLM 等兼容 API 都是正确的行为
usage.InputTokens = msgDelta.Usage.InputTokens
usage.OutputTokens = msgDelta.Usage.OutputTokens
// 如果 message_start 中没有值,则从 message_delta 获取兼容GLM等API
if usage.InputTokens == 0 {
usage.InputTokens = msgDelta.Usage.InputTokens
}
if usage.CacheCreationInputTokens == 0 {
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
}
if usage.CacheReadInputTokens == 0 {
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
}
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
}
}

View File

@@ -931,6 +931,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
}
// 图片生成计费
imageCount := 0
imageSize := s.extractImageSize(body)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
@@ -938,6 +945,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Stream: req.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
@@ -1371,6 +1380,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
usage = &ClaudeUsage{}
}
// 图片生成计费
imageCount := 0
imageSize := s.extractImageSize(body)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
@@ -1378,6 +1394,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
@@ -1972,6 +1990,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
var last map[string]any
var lastWithParts map[string]any
var collectedTextParts []string // Collect all text parts for aggregation
usage := &ClaudeUsage{}
for {
@@ -1983,7 +2002,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
switch payload {
case "", "[DONE]":
if payload == "[DONE]" {
return pickGeminiCollectResult(last, lastWithParts), usage, nil
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
}
default:
var parsed map[string]any
@@ -2002,6 +2021,12 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
// Collect text from each part for aggregation
for _, part := range parts {
if text, ok := part["text"].(string); ok && text != "" {
collectedTextParts = append(collectedTextParts, text)
}
}
}
}
}
@@ -2016,7 +2041,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
}
return pickGeminiCollectResult(last, lastWithParts), usage, nil
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
}
func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any {
@@ -2029,6 +2054,83 @@ func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any)
return map[string]any{}
}
// mergeCollectedTextParts merges all collected text chunks into the final response.
// This fixes the issue where non-streaming responses only returned the last chunk
// instead of the complete aggregated text.
func mergeCollectedTextParts(response map[string]any, textParts []string) map[string]any {
if len(textParts) == 0 {
return response
}
// Join all text parts
mergedText := strings.Join(textParts, "")
// Deep copy response
result := make(map[string]any)
for k, v := range response {
result[k] = v
}
// Get or create candidates
candidates, ok := result["candidates"].([]any)
if !ok || len(candidates) == 0 {
candidates = []any{map[string]any{}}
}
// Get first candidate
candidate, ok := candidates[0].(map[string]any)
if !ok {
candidate = make(map[string]any)
candidates[0] = candidate
}
// Get or create content
content, ok := candidate["content"].(map[string]any)
if !ok {
content = map[string]any{"role": "model"}
candidate["content"] = content
}
// Get existing parts
existingParts, ok := content["parts"].([]any)
if !ok {
existingParts = []any{}
}
// Find and update first text part, or create new one
newParts := make([]any, 0, len(existingParts)+1)
textUpdated := false
for _, p := range existingParts {
pm, ok := p.(map[string]any)
if !ok {
newParts = append(newParts, p)
continue
}
if _, hasText := pm["text"]; hasText && !textUpdated {
// Replace with merged text
newPart := make(map[string]any)
for k, v := range pm {
newPart[k] = v
}
newPart["text"] = mergedText
newParts = append(newParts, newPart)
textUpdated = true
} else {
newParts = append(newParts, pm)
}
}
if !textUpdated {
newParts = append([]any{map[string]any{"text": mergedText}}, newParts...)
}
content["parts"] = newParts
result["candidates"] = candidates
return result
}
type geminiNativeStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
@@ -2947,3 +3049,26 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any {
}
return out
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string {
var req struct {
GenerationConfig *struct {
ImageConfig *struct {
ImageSize string `json:"imageSize"`
} `json:"imageConfig"`
} `json:"generationConfig"`
}
if err := json.Unmarshal(body, &req); err != nil {
return "2K"
}
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
if size == "1K" || size == "2K" || size == "4K" {
return size
}
}
return "2K"
}

View File

@@ -0,0 +1,72 @@
package service
import (
"encoding/json"
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
// 以避免跨账号签名验证错误。
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
//
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
// to avoid cross-account signature validation errors.
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// By removing these signatures, we allow the new account to generate valid signatures.
func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
if len(body) == 0 {
return body
}
// 解析 JSON
var data any
if err := json.Unmarshal(body, &data); err != nil {
// 如果解析失败,返回原始 body可能不是 JSON 或格式不正确)
return body
}
// 递归清理 thoughtSignature
cleaned := cleanThoughtSignaturesRecursive(data)
// 重新序列化
result, err := json.Marshal(cleaned)
if err != nil {
// 如果序列化失败,返回原始 body
return body
}
return result
}
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
func cleanThoughtSignaturesRecursive(data any) any {
switch v := data.(type) {
case map[string]any:
// 创建新的 map移除 thoughtSignature
result := make(map[string]any, len(v))
for key, value := range v {
// 跳过 thoughtSignature 字段
if key == "thoughtSignature" {
continue
}
// 递归处理嵌套结构
result[key] = cleanThoughtSignaturesRecursive(value)
}
return result
case []any:
// 递归处理数组中的每个元素
result := make([]any, len(v))
for i, item := range v {
result[i] = cleanThoughtSignaturesRecursive(item)
}
return result
default:
// 基本类型string, number, bool, null直接返回
return v
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
"time"
@@ -131,21 +132,32 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
// 3) Populate cache with TTL.
// 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > geminiTokenCacheSkew:
ttl = until - geminiTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > geminiTokenCacheSkew:
ttl = until - geminiTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil

View File

@@ -122,6 +122,7 @@ type TokenInfo struct {
Scope string `json:"scope,omitempty"`
OrgUUID string `json:"org_uuid,omitempty"`
AccountUUID string `json:"account_uuid,omitempty"`
EmailAddress string `json:"email_address,omitempty"`
}
// ExchangeCode exchanges authorization code for tokens
@@ -252,9 +253,15 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
tokenInfo.OrgUUID = tokenResp.Organization.UUID
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
}
if tokenResp.Account != nil && tokenResp.Account.UUID != "" {
tokenInfo.AccountUUID = tokenResp.Account.UUID
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
if tokenResp.Account != nil {
if tokenResp.Account.UUID != "" {
tokenInfo.AccountUUID = tokenResp.Account.UUID
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
}
if tokenResp.Account.EmailAddress != "" {
tokenInfo.EmailAddress = tokenResp.Account.EmailAddress
log.Printf("[OAuth] Got email_address: %s", tokenInfo.EmailAddress)
}
}
return tokenInfo, nil

View File

@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
UpdatedAt string `json:"updated_at,omitempty"`
}
// NormalizedCodexLimits contains normalized 5h/7d rate limit data
type NormalizedCodexLimits struct {
Used5hPercent *float64
Reset5hSeconds *int
Window5hMinutes *int
Used7dPercent *float64
Reset7dSeconds *int
Window7dMinutes *int
}
// Normalize converts primary/secondary fields to canonical 5h/7d fields.
// Strategy: Compare window_minutes to determine which is 5h vs 7d.
// Returns nil if snapshot is nil or has no useful data.
func (s *OpenAICodexUsageSnapshot) Normalize() *NormalizedCodexLimits {
if s == nil {
return nil
}
result := &NormalizedCodexLimits{}
primaryMins := 0
secondaryMins := 0
hasPrimaryWindow := false
hasSecondaryWindow := false
if s.PrimaryWindowMinutes != nil {
primaryMins = *s.PrimaryWindowMinutes
hasPrimaryWindow = true
}
if s.SecondaryWindowMinutes != nil {
secondaryMins = *s.SecondaryWindowMinutes
hasSecondaryWindow = true
}
// Determine mapping based on window_minutes
use5hFromPrimary := false
use7dFromPrimary := false
if hasPrimaryWindow && hasSecondaryWindow {
// Both known: smaller window is 5h, larger is 7d
if primaryMins < secondaryMins {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
}
} else if hasPrimaryWindow {
// Only primary known: classify by threshold (<=360 min = 6h -> 5h window)
if primaryMins <= 360 {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
}
} else if hasSecondaryWindow {
// Only secondary known: classify by threshold
if secondaryMins <= 360 {
// 5h from secondary, so primary (if any data) is 7d
use7dFromPrimary = true
} else {
// 7d from secondary, so primary (if any data) is 5h
use5hFromPrimary = true
}
} else {
// No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h)
use7dFromPrimary = true
}
// Assign values
if use5hFromPrimary {
result.Used5hPercent = s.PrimaryUsedPercent
result.Reset5hSeconds = s.PrimaryResetAfterSeconds
result.Window5hMinutes = s.PrimaryWindowMinutes
result.Used7dPercent = s.SecondaryUsedPercent
result.Reset7dSeconds = s.SecondaryResetAfterSeconds
result.Window7dMinutes = s.SecondaryWindowMinutes
} else if use7dFromPrimary {
result.Used7dPercent = s.PrimaryUsedPercent
result.Reset7dSeconds = s.PrimaryResetAfterSeconds
result.Window7dMinutes = s.PrimaryWindowMinutes
result.Used5hPercent = s.SecondaryUsedPercent
result.Reset5hSeconds = s.SecondaryResetAfterSeconds
result.Window5hMinutes = s.SecondaryWindowMinutes
}
return result
}
// OpenAIUsage represents OpenAI API response usage
type OpenAIUsage struct {
InputTokens int `json:"input_tokens"`
@@ -867,7 +953,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == AccountTypeOAuth {
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
@@ -1665,8 +1751,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
return nil
}
// extractCodexUsageHeaders extracts Codex usage limits from response headers
func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
snapshot := &OpenAICodexUsageSnapshot{}
hasData := false
@@ -1740,6 +1827,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
// Convert snapshot to map for merging into Extra
updates := make(map[string]any)
// Save raw primary/secondary fields for debugging/tracing
if snapshot.PrimaryUsedPercent != nil {
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
}
@@ -1763,109 +1852,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
}
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
// Normalize to canonical 5h/7d fields based on window_minutes
// This fixes the issue where OpenAI's primary/secondary naming is reversed
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d
// IMPORTANT: We can only reliably determine window type from window_minutes field
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison
var primaryWindowMins, secondaryWindowMins int
var hasPrimaryWindow, hasSecondaryWindow bool
// Only use window_minutes for reliable window size comparison
if snapshot.PrimaryWindowMinutes != nil {
primaryWindowMins = *snapshot.PrimaryWindowMinutes
hasPrimaryWindow = true
}
if snapshot.SecondaryWindowMinutes != nil {
secondaryWindowMins = *snapshot.SecondaryWindowMinutes
hasSecondaryWindow = true
}
// Determine which is 5h and which is 7d
var use5hFromPrimary, use7dFromPrimary bool
var use5hFromSecondary, use7dFromSecondary bool
if hasPrimaryWindow && hasSecondaryWindow {
// Both window sizes known: compare and assign smaller to 5h, larger to 7d
if primaryWindowMins < secondaryWindowMins {
use5hFromPrimary = true
use7dFromSecondary = true
} else {
use5hFromSecondary = true
use7dFromPrimary = true
// Normalize to canonical 5h/7d fields
if normalized := snapshot.Normalize(); normalized != nil {
if normalized.Used5hPercent != nil {
updates["codex_5h_used_percent"] = *normalized.Used5hPercent
}
} else if hasPrimaryWindow {
// Only primary window size known: classify by absolute threshold
if primaryWindowMins <= 360 {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
if normalized.Reset5hSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds
}
} else if hasSecondaryWindow {
// Only secondary window size known: classify by absolute threshold
if secondaryWindowMins <= 360 {
use5hFromSecondary = true
} else {
use7dFromSecondary = true
if normalized.Window5hMinutes != nil {
updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes
}
} else {
// No window_minutes available: cannot reliably determine window types
// Fall back to legacy assumption (may be incorrect)
// Assume primary=7d, secondary=5h based on historical observation
if snapshot.SecondaryUsedPercent != nil || snapshot.SecondaryResetAfterSeconds != nil || snapshot.SecondaryWindowMinutes != nil {
use5hFromSecondary = true
if normalized.Used7dPercent != nil {
updates["codex_7d_used_percent"] = *normalized.Used7dPercent
}
if snapshot.PrimaryUsedPercent != nil || snapshot.PrimaryResetAfterSeconds != nil || snapshot.PrimaryWindowMinutes != nil {
use7dFromPrimary = true
if normalized.Reset7dSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds
}
}
// Write canonical 5h fields
if use5hFromPrimary {
if snapshot.PrimaryUsedPercent != nil {
updates["codex_5h_used_percent"] = *snapshot.PrimaryUsedPercent
}
if snapshot.PrimaryResetAfterSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
}
if snapshot.PrimaryWindowMinutes != nil {
updates["codex_5h_window_minutes"] = *snapshot.PrimaryWindowMinutes
}
} else if use5hFromSecondary {
if snapshot.SecondaryUsedPercent != nil {
updates["codex_5h_used_percent"] = *snapshot.SecondaryUsedPercent
}
if snapshot.SecondaryResetAfterSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
}
if snapshot.SecondaryWindowMinutes != nil {
updates["codex_5h_window_minutes"] = *snapshot.SecondaryWindowMinutes
}
}
// Write canonical 7d fields
if use7dFromPrimary {
if snapshot.PrimaryUsedPercent != nil {
updates["codex_7d_used_percent"] = *snapshot.PrimaryUsedPercent
}
if snapshot.PrimaryResetAfterSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
}
if snapshot.PrimaryWindowMinutes != nil {
updates["codex_7d_window_minutes"] = *snapshot.PrimaryWindowMinutes
}
} else if use7dFromSecondary {
if snapshot.SecondaryUsedPercent != nil {
updates["codex_7d_used_percent"] = *snapshot.SecondaryUsedPercent
}
if snapshot.SecondaryResetAfterSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
}
if snapshot.SecondaryWindowMinutes != nil {
updates["codex_7d_window_minutes"] = *snapshot.SecondaryWindowMinutes
if normalized.Window7dMinutes != nil {
updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes
}
}

View File

@@ -2,9 +2,10 @@ package service
import (
"context"
"fmt"
"net/http"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
@@ -35,12 +36,12 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// Generate PKCE values
state, err := openai.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_STATE_FAILED", "failed to generate state: %v", err)
}
codeVerifier, err := openai.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_VERIFIER_FAILED", "failed to generate code verifier: %v", err)
}
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
@@ -48,14 +49,17 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// Generate session ID
sessionID, err := openai.GenerateSessionID()
if err != nil {
return nil, fmt.Errorf("failed to generate session ID: %w", err)
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_SESSION_FAILED", "failed to generate session ID: %v", err)
}
// Get proxy URL if specified
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil {
if err != nil {
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy != nil {
proxyURL = proxy.URL()
}
}
@@ -110,14 +114,17 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Get session
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
return nil, fmt.Errorf("session not found or expired")
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
}
// Get proxy URL
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
proxyURL := session.ProxyURL
if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err == nil && proxy != nil {
if err != nil {
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy != nil {
proxyURL = proxy.URL()
}
}
@@ -131,7 +138,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Exchange code for token
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err)
return nil, err
}
// Parse ID token to get user info
@@ -201,12 +208,12 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
// RefreshAccountToken refreshes token for an OpenAI account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if !account.IsOpenAI() {
return nil, fmt.Errorf("account is not an OpenAI account")
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
}
refreshToken := account.GetOpenAIRefreshToken()
if refreshToken == "" {
return nil, fmt.Errorf("no refresh token available")
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
}
var proxyURL string

View File

@@ -162,26 +162,37 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > openAITokenCacheSkew:
ttl = until - openAITokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("openai_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetOpenAIAccessToken()
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > openAITokenCacheSkew:
ttl = until - openAITokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}

View File

@@ -343,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
// 解析重置时间戳
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded
if account.Platform == PlatformOpenAI {
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("openai_account_rate_limited", "account_id", account.ID, "reset_at", *resetAt)
return
}
}
// 2. 尝试从响应头解析重置时间Anthropic
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
// 3. 如果响应头没有尝试从响应体解析OpenAI usage_limit_reached, Gemini
if resetTimestamp == "" {
switch account.Platform {
case PlatformOpenAI:
// 尝试解析 OpenAI 的 usage_limit_reached 错误
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
resetTime := time.Unix(*resetAt, 0)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
return
}
case PlatformGemini, PlatformAntigravity:
// 尝试解析 Gemini 格式(用于其他平台)
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
resetTime := time.Unix(*resetAt, 0)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
return
}
}
// 没有重置时间使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
@@ -356,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
}
return
}
slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
@@ -419,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re
return strings.Contains(msg, "sonnet")
}
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
snapshot := ParseCodexRateLimitHeaders(headers)
if snapshot == nil {
return nil
}
normalized := snapshot.Normalize()
if normalized == nil {
return nil
}
now := time.Now()
// 判断哪个限制被触发used_percent >= 100
is7dExhausted := normalized.Used7dPercent != nil && *normalized.Used7dPercent >= 100
is5hExhausted := normalized.Used5hPercent != nil && *normalized.Used5hPercent >= 100
// 优先使用被触发限制的重置时间
if is7dExhausted && normalized.Reset7dSeconds != nil {
resetAt := now.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
slog.Info("openai_429_7d_limit_exhausted", "reset_after_seconds", *normalized.Reset7dSeconds, "reset_at", resetAt)
return &resetAt
}
if is5hExhausted && normalized.Reset5hSeconds != nil {
resetAt := now.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
slog.Info("openai_429_5h_limit_exhausted", "reset_after_seconds", *normalized.Reset5hSeconds, "reset_at", resetAt)
return &resetAt
}
// 都未达到100%但收到429使用较长的重置时间
var maxResetSecs int
if normalized.Reset7dSeconds != nil && *normalized.Reset7dSeconds > maxResetSecs {
maxResetSecs = *normalized.Reset7dSeconds
}
if normalized.Reset5hSeconds != nil && *normalized.Reset5hSeconds > maxResetSecs {
maxResetSecs = *normalized.Reset5hSeconds
}
if maxResetSecs > 0 {
resetAt := now.Add(time.Duration(maxResetSecs) * time.Second)
slog.Info("openai_429_using_max_reset", "max_reset_seconds", maxResetSecs, "reset_at", resetAt)
return &resetAt
}
return nil
}
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
// OpenAI 的 usage_limit_reached 错误格式:
//
// {
// "error": {
// "message": "The usage limit has been reached",
// "type": "usage_limit_reached",
// "resets_at": 1769404154,
// "resets_in_seconds": 133107
// }
// }
func parseOpenAIRateLimitResetTime(body []byte) *int64 {
var parsed map[string]any
if err := json.Unmarshal(body, &parsed); err != nil {
return nil
}
errObj, ok := parsed["error"].(map[string]any)
if !ok {
return nil
}
// 检查是否为 usage_limit_reached 或 rate_limit_exceeded 类型
errType, _ := errObj["type"].(string)
if errType != "usage_limit_reached" && errType != "rate_limit_exceeded" {
return nil
}
// 优先使用 resets_atUnix 时间戳)
if resetsAt, ok := errObj["resets_at"].(float64); ok {
ts := int64(resetsAt)
return &ts
}
if resetsAt, ok := errObj["resets_at"].(string); ok {
if ts, err := strconv.ParseInt(resetsAt, 10, 64); err == nil {
return &ts
}
}
// 如果没有 resets_at尝试使用 resets_in_seconds
if resetsInSeconds, ok := errObj["resets_in_seconds"].(float64); ok {
ts := time.Now().Unix() + int64(resetsInSeconds)
return &ts
}
if resetsInSeconds, ok := errObj["resets_in_seconds"].(string); ok {
if sec, err := strconv.ParseInt(resetsInSeconds, 10, 64); err == nil {
ts := time.Now().Unix() + sec
return &ts
}
}
return nil
}
// handle529 处理529过载错误
// 根据配置设置过载冷却时间
func (s *RateLimitService) handle529(ctx context.Context, account *Account) {

View File

@@ -0,0 +1,364 @@
package service
import (
"net/http"
"testing"
"time"
)
func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) {
svc := &RateLimitService{}
// Simulate headers when 7d limit is exhausted (100% used)
// Primary = 7d (10080 minutes), Secondary = 5h (300 minutes)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "384607") // ~4.5 days
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
headers.Set("x-codex-secondary-used-percent", "3")
headers.Set("x-codex-secondary-reset-after-seconds", "17369") // ~4.8 hours
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should be approximately 384607 seconds from now
expectedDuration := 384607 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestCalculateOpenAI429ResetTime_5hExhausted(t *testing.T) {
svc := &RateLimitService{}
// Simulate headers when 5h limit is exhausted (100% used)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "50")
headers.Set("x-codex-primary-reset-after-seconds", "500000")
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
headers.Set("x-codex-secondary-used-percent", "100")
headers.Set("x-codex-secondary-reset-after-seconds", "3600") // 1 hour
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should be approximately 3600 seconds from now
expectedDuration := 3600 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestCalculateOpenAI429ResetTime_NeitherExhausted_UsesMax(t *testing.T) {
svc := &RateLimitService{}
// Neither limit at 100%, should use the longer reset time
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "80")
headers.Set("x-codex-primary-reset-after-seconds", "100000")
headers.Set("x-codex-primary-window-minutes", "10080")
headers.Set("x-codex-secondary-used-percent", "90")
headers.Set("x-codex-secondary-reset-after-seconds", "5000")
headers.Set("x-codex-secondary-window-minutes", "300")
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should use the max (100000 seconds from 7d window)
expectedDuration := 100000 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestCalculateOpenAI429ResetTime_NoCodexHeaders(t *testing.T) {
svc := &RateLimitService{}
// No codex headers at all
headers := http.Header{}
headers.Set("content-type", "application/json")
resetAt := svc.calculateOpenAI429ResetTime(headers)
if resetAt != nil {
t.Errorf("expected nil resetAt when no codex headers, got %v", resetAt)
}
}
func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) {
svc := &RateLimitService{}
// Test when OpenAI sends primary as 5h and secondary as 7d (reversed)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100") // This is 5h
headers.Set("x-codex-primary-reset-after-seconds", "3600") // 1 hour
headers.Set("x-codex-primary-window-minutes", "300") // 5 hours - smaller!
headers.Set("x-codex-secondary-used-percent", "50")
headers.Set("x-codex-secondary-reset-after-seconds", "500000")
headers.Set("x-codex-secondary-window-minutes", "10080") // 7 days - larger!
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should correctly identify that primary is 5h (smaller window) and use its reset time
expectedDuration := 3600 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestNormalizedCodexLimits(t *testing.T) {
// Test the Normalize() method directly
pUsed := 100.0
pReset := 384607
pWindow := 10080
sUsed := 3.0
sReset := 17369
sWindow := 300
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: &pUsed,
PrimaryResetAfterSeconds: &pReset,
PrimaryWindowMinutes: &pWindow,
SecondaryUsedPercent: &sUsed,
SecondaryResetAfterSeconds: &sReset,
SecondaryWindowMinutes: &sWindow,
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Primary has larger window (10080 > 300), so primary should be 7d
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
}
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 384607 {
t.Errorf("expected Reset7dSeconds=384607, got %v", normalized.Reset7dSeconds)
}
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 3.0 {
t.Errorf("expected Used5hPercent=3, got %v", normalized.Used5hPercent)
}
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 17369 {
t.Errorf("expected Reset5hSeconds=17369, got %v", normalized.Reset5hSeconds)
}
}
func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) {
// Test when only primary has data, no window_minutes
pUsed := 80.0
pReset := 50000
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: &pUsed,
PrimaryResetAfterSeconds: &pReset,
// No window_minutes, no secondary data
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Legacy assumption: primary=7d, secondary=5h
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 80.0 {
t.Errorf("expected Used7dPercent=80, got %v", normalized.Used7dPercent)
}
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 50000 {
t.Errorf("expected Reset7dSeconds=50000, got %v", normalized.Reset7dSeconds)
}
// Secondary (5h) should be nil
if normalized.Used5hPercent != nil {
t.Errorf("expected Used5hPercent=nil, got %v", *normalized.Used5hPercent)
}
if normalized.Reset5hSeconds != nil {
t.Errorf("expected Reset5hSeconds=nil, got %v", *normalized.Reset5hSeconds)
}
}
func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) {
// Test when only secondary has data, no window_minutes
sUsed := 60.0
sReset := 3000
snapshot := &OpenAICodexUsageSnapshot{
SecondaryUsedPercent: &sUsed,
SecondaryResetAfterSeconds: &sReset,
// No window_minutes, no primary data
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Legacy assumption: primary=7d, secondary=5h
// So secondary goes to 5h
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 60.0 {
t.Errorf("expected Used5hPercent=60, got %v", normalized.Used5hPercent)
}
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 3000 {
t.Errorf("expected Reset5hSeconds=3000, got %v", normalized.Reset5hSeconds)
}
// Primary (7d) should be nil
if normalized.Used7dPercent != nil {
t.Errorf("expected Used7dPercent=nil, got %v", *normalized.Used7dPercent)
}
}
func TestNormalizedCodexLimits_BothDataNoWindowMinutes(t *testing.T) {
// Test when both have data but no window_minutes
pUsed := 100.0
pReset := 400000
sUsed := 50.0
sReset := 10000
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: &pUsed,
PrimaryResetAfterSeconds: &pReset,
SecondaryUsedPercent: &sUsed,
SecondaryResetAfterSeconds: &sReset,
// No window_minutes
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Legacy assumption: primary=7d, secondary=5h
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
}
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 400000 {
t.Errorf("expected Reset7dSeconds=400000, got %v", normalized.Reset7dSeconds)
}
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 50.0 {
t.Errorf("expected Used5hPercent=50, got %v", normalized.Used5hPercent)
}
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 10000 {
t.Errorf("expected Reset5hSeconds=10000, got %v", normalized.Reset5hSeconds)
}
}
func TestHandle429_AnthropicPlatformUnaffected(t *testing.T) {
// Verify that Anthropic platform accounts still use the original logic
// This test ensures we don't break existing Claude account rate limiting
svc := &RateLimitService{}
// Simulate Anthropic 429 headers
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-reset", "1737820800") // A future Unix timestamp
// For Anthropic platform, calculateOpenAI429ResetTime should return nil
// because it only handles OpenAI platform
resetAt := svc.calculateOpenAI429ResetTime(headers)
// Should return nil since there are no x-codex-* headers
if resetAt != nil {
t.Errorf("expected nil for Anthropic headers, got %v", resetAt)
}
}
func TestCalculateOpenAI429ResetTime_UserProvidedScenario(t *testing.T) {
// This is the exact scenario from the user:
// codex_7d_used_percent: 100
// codex_7d_reset_after_seconds: 384607 (约4.5天后重置)
// codex_5h_used_percent: 3
// codex_5h_reset_after_seconds: 17369 (约4.8小时后重置)
svc := &RateLimitService{}
// Simulate headers matching user's data
// Note: We need to map the canonical 5h/7d back to primary/secondary
// Based on typical OpenAI behavior: primary=7d (larger window), secondary=5h (smaller window)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "384607")
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days = 10080 minutes
headers.Set("x-codex-secondary-used-percent", "3")
headers.Set("x-codex-secondary-reset-after-seconds", "17369")
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours = 300 minutes
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt for user scenario")
}
// Should use the 7d reset time (384607 seconds) since 7d limit is exhausted (100%)
expectedDuration := 384607 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
// Verify it's approximately 4.45 days (384607 seconds)
duration := resetAt.Sub(before)
actualDays := duration.Hours() / 24.0
// 384607 / 86400 = ~4.45 days
if actualDays < 4.4 || actualDays > 4.5 {
t.Errorf("expected ~4.45 days, got %.2f days", actualDays)
}
t.Logf("User scenario: reset_at=%v, duration=%.2f days", resetAt, actualDays)
}
func TestCalculateOpenAI429ResetTime_5MinFallbackWhenNoReset(t *testing.T) {
// Test that we return nil when there's used_percent but no reset_after_seconds
// This should cause the caller to use the default 5-minute fallback
svc := &RateLimitService{}
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100")
// No reset_after_seconds!
resetAt := svc.calculateOpenAI429ResetTime(headers)
// Should return nil since there's no reset time available
if resetAt != nil {
t.Errorf("expected nil when no reset_after_seconds, got %v", resetAt)
}
}

View File

@@ -61,6 +61,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled,
SettingKeyPromoCodeEnabled,
SettingKeyPasswordResetEnabled,
SettingKeyTotpEnabled,
SettingKeyTurnstileEnabled,
SettingKeyTurnstileSiteKey,
SettingKeySiteName,
@@ -86,21 +88,27 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
}
// Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true"
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent],
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
LinuxDoOAuthEnabled: linuxDoEnabled,
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
PasswordResetEnabled: passwordResetEnabled,
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent],
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
LinuxDoOAuthEnabled: linuxDoEnabled,
}, nil
}
@@ -125,37 +133,41 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
// Return a struct that matches the frontend's expected format
return &struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo,omitempty"`
SiteSubtitle string `json:"site_subtitle,omitempty"`
APIBaseURL string `json:"api_base_url,omitempty"`
ContactInfo string `json:"contact_info,omitempty"`
DocURL string `json:"doc_url,omitempty"`
HomeContent string `json:"home_content,omitempty"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version,omitempty"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo,omitempty"`
SiteSubtitle string `json:"site_subtitle,omitempty"`
APIBaseURL string `json:"api_base_url,omitempty"`
ContactInfo string `json:"contact_info,omitempty"`
DocURL string `json:"doc_url,omitempty"`
HomeContent string `json:"home_content,omitempty"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version,omitempty"`
}{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: s.version,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: s.version,
}, nil
}
@@ -167,6 +179,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled)
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[SettingKeySMTPHost] = settings.SMTPHost
@@ -262,6 +276,35 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
return value != "false"
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
// Password reset requires email verification to be enabled
if !s.IsEmailVerifyEnabled(ctx) {
return false
}
value, err := s.settingRepo.GetValue(ctx, SettingKeyPasswordResetEnabled)
if err != nil {
return false // 默认关闭
}
return value == "true"
}
// IsTotpEnabled 检查是否启用 TOTP 双因素认证功能
func (s *SettingService) IsTotpEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyTotpEnabled)
if err != nil {
return false // 默认关闭
}
return value == "true"
}
// IsTotpEncryptionKeyConfigured 检查 TOTP 加密密钥是否已手动配置
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
func (s *SettingService) IsTotpEncryptionKeyConfigured() bool {
return s.cfg.Totp.EncryptionKeyConfigured
}
// GetSiteName 获取网站名称
func (s *SettingService) GetSiteName(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
@@ -340,10 +383,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],

View File

@@ -1,9 +1,11 @@
package service
type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
PromoCodeEnabled bool
RegistrationEnabled bool
EmailVerifyEnabled bool
PromoCodeEnabled bool
PasswordResetEnabled bool
TotpEnabled bool // TOTP 双因素认证
SMTPHost string
SMTPPort int
@@ -57,21 +59,23 @@ type SystemSettings struct {
}
type PublicSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
PromoCodeEnabled bool
TurnstileEnabled bool
TurnstileSiteKey string
SiteName string
SiteLogo string
SiteSubtitle string
APIBaseURL string
ContactInfo string
DocURL string
HomeContent string
HideCcsImportButton bool
LinuxDoOAuthEnabled bool
Version string
RegistrationEnabled bool
EmailVerifyEnabled bool
PromoCodeEnabled bool
PasswordResetEnabled bool
TotpEnabled bool // TOTP 双因素认证
TurnstileEnabled bool
TurnstileSiteKey string
SiteName string
SiteLogo string
SiteSubtitle string
APIBaseURL string
ContactInfo string
DocURL string
HomeContent string
HideCcsImportButton bool
LinuxDoOAuthEnabled bool
Version string
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)

View File

@@ -0,0 +1,71 @@
package service
import (
"context"
"log"
"sync"
"time"
)
// SubscriptionExpiryService periodically updates expired subscription status.
type SubscriptionExpiryService struct {
userSubRepo UserSubscriptionRepository
interval time.Duration
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}
func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interval time.Duration) *SubscriptionExpiryService {
return &SubscriptionExpiryService{
userSubRepo: userSubRepo,
interval: interval,
stopCh: make(chan struct{}),
}
}
func (s *SubscriptionExpiryService) Start() {
if s == nil || s.userSubRepo == nil || s.interval <= 0 {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.runOnce()
for {
select {
case <-ticker.C:
s.runOnce()
case <-s.stopCh:
return
}
}
}()
}
func (s *SubscriptionExpiryService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
close(s.stopCh)
})
s.wg.Wait()
}
func (s *SubscriptionExpiryService) runOnce() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
updated, err := s.userSubRepo.BatchUpdateExpiredStatus(ctx)
if err != nil {
log.Printf("[SubscriptionExpiry] Update expired subscriptions failed: %v", err)
return
}
if updated > 0 {
log.Printf("[SubscriptionExpiry] Updated %d expired subscriptions", updated)
}
}

View File

@@ -324,18 +324,31 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
days = -MaxValidityDays
}
now := time.Now()
isExpired := !sub.ExpiresAt.After(now)
// 如果订阅已过期,不允许负向调整
if isExpired && days < 0 {
return nil, infraerrors.BadRequest("CANNOT_SHORTEN_EXPIRED", "cannot shorten an expired subscription")
}
// 计算新的过期时间
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
var newExpiresAt time.Time
if isExpired {
// 已过期:从当前时间开始增加天数
newExpiresAt = now.AddDate(0, 0, days)
} else {
// 未过期:从原过期时间增加/减少天数
newExpiresAt = sub.ExpiresAt.AddDate(0, 0, days)
}
if newExpiresAt.After(MaxExpiresAt) {
newExpiresAt = MaxExpiresAt
}
// 如果是缩短(负数),检查新的过期时间必须大于当前时间
if days < 0 {
now := time.Now()
if !newExpiresAt.After(now) {
return nil, ErrAdjustWouldExpire
}
// 检查新的过期时间必须大于当前时间
if !newExpiresAt.After(now) {
return nil, ErrAdjustWouldExpire
}
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
@@ -383,6 +396,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
return nil, err
}
normalizeExpiredWindows(subs)
normalizeSubscriptionStatus(subs)
return subs, nil
}
@@ -404,17 +418,19 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
return nil, nil, err
}
normalizeExpiredWindows(subs)
normalizeSubscriptionStatus(subs)
return subs, pag, nil
}
// List 获取所有订阅(分页,支持筛选)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error) {
// List 获取所有订阅(分页,支持筛选和排序
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status)
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder)
if err != nil {
return nil, nil, err
}
normalizeExpiredWindows(subs)
normalizeSubscriptionStatus(subs)
return subs, pag, nil
}
@@ -441,6 +457,18 @@ func normalizeExpiredWindows(subs []UserSubscription) {
}
}
// normalizeSubscriptionStatus 根据实际过期时间修正状态(仅影响返回数据,不影响数据库)
// 这确保前端显示正确的状态,即使定时任务尚未更新数据库
func normalizeSubscriptionStatus(subs []UserSubscription) {
now := time.Now()
for i := range subs {
sub := &subs[i]
if sub.Status == SubscriptionStatusActive && !sub.ExpiresAt.After(now) {
sub.Status = SubscriptionStatusExpired
}
}
}
// startOfDay 返回给定时间所在日期的零点(保持原时区)
func startOfDay(t time.Time) time.Time {
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
@@ -659,11 +687,6 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
return progresses, nil
}
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
return s.userSubRepo.BatchUpdateExpiredStatus(ctx)
}
// ValidateSubscription 验证订阅是否有效
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *UserSubscription) error {
if sub.Status == SubscriptionStatusExpired {

View File

@@ -1,6 +1,10 @@
package service
import "context"
import (
"context"
"log/slog"
"strconv"
)
type TokenCacheInvalidator interface {
InvalidateToken(ctx context.Context, account *Account) error
@@ -24,18 +28,87 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
return nil
}
var cacheKey string
var keysToDelete []string
accountIDKey := "account:" + strconv.FormatInt(account.ID, 10)
switch account.Platform {
case PlatformGemini:
cacheKey = GeminiTokenCacheKey(account)
// Gemini 可能有两种缓存键project_id 或 account_id
// 首次获取 token 时可能没有 project_id之后自动检测到 project_id 后会使用新 key
// 刷新时需要同时删除两种可能的 key确保不会遗留旧缓存
keysToDelete = append(keysToDelete, GeminiTokenCacheKey(account))
keysToDelete = append(keysToDelete, "gemini:"+accountIDKey)
case PlatformAntigravity:
cacheKey = AntigravityTokenCacheKey(account)
// Antigravity 同样可能有两种缓存键
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
case PlatformOpenAI:
cacheKey = OpenAITokenCacheKey(account)
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
case PlatformAnthropic:
cacheKey = ClaudeTokenCacheKey(account)
keysToDelete = append(keysToDelete, ClaudeTokenCacheKey(account))
default:
return nil
}
return c.cache.DeleteAccessToken(ctx, cacheKey)
// 删除所有可能的缓存键(去重后)
seen := make(map[string]bool)
for _, key := range keysToDelete {
if seen[key] {
continue
}
seen[key] = true
if err := c.cache.DeleteAccessToken(ctx, key); err != nil {
slog.Warn("token_cache_delete_failed", "key", key, "account_id", account.ID, "error", err)
}
}
return nil
}
// CheckTokenVersion 检查 account 的 token 版本是否已过时,并返回最新的 account
// 用于解决异步刷新任务与请求线程的竞态条件:
// 如果刷新任务已更新 token 并删除缓存,此时请求线程的旧 account 对象不应写入缓存
//
// 返回值:
// - latestAccount: 从 DB 获取的最新 account如果查询失败则返回 nil
// - isStale: true 表示 token 已过时(应使用 latestAccountfalse 表示可以使用当前 account
func CheckTokenVersion(ctx context.Context, account *Account, repo AccountRepository) (latestAccount *Account, isStale bool) {
if account == nil || repo == nil {
return nil, false
}
currentVersion := account.GetCredentialAsInt64("_token_version")
latestAccount, err := repo.GetByID(ctx, account.ID)
if err != nil || latestAccount == nil {
// 查询失败,默认允许缓存,不返回 latestAccount
return nil, false
}
latestVersion := latestAccount.GetCredentialAsInt64("_token_version")
// 情况1: 当前 account 没有版本号,但 DB 中已有版本号
// 说明异步刷新任务已更新 token当前 account 已过时
if currentVersion == 0 && latestVersion > 0 {
slog.Debug("token_version_stale_no_current_version",
"account_id", account.ID,
"latest_version", latestVersion)
return latestAccount, true
}
// 情况2: 两边都没有版本号,说明从未被异步刷新过,允许缓存
if currentVersion == 0 && latestVersion == 0 {
return latestAccount, false
}
// 情况3: 比较版本号,如果 DB 中的版本更新,当前 account 已过时
if latestVersion > currentVersion {
slog.Debug("token_version_stale",
"account_id", account.ID,
"current_version", currentVersion,
"latest_version", latestVersion)
return latestAccount, true
}
return latestAccount, false
}

View File

@@ -51,7 +51,27 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"gemini:project-x"}, cache.deletedKeys)
// 新行为:同时删除基于 project_id 和 account_id 的缓存键
// 这是为了处理:首次获取 token 时可能没有 project_id之后自动检测到后会使用新 key
require.Equal(t, []string{"gemini:project-x", "gemini:account:10"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_GeminiWithoutProjectID(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 10,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "gemini-token",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
// 没有 project_id 时,两个 key 相同,去重后只删除一个
require.Equal(t, []string{"gemini:account:10"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
@@ -68,7 +88,26 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
// 新行为:同时删除基于 project_id 和 account_id 的缓存键
require.Equal(t, []string{"ag:ag-project", "ag:account:99"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_AntigravityWithoutProjectID(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 99,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "ag-token",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
// 没有 project_id 时,两个 key 相同,去重后只删除一个
require.Equal(t, []string{"ag:account:99"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_OpenAI(t *testing.T) {
@@ -233,9 +272,10 @@ func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 新行为:删除失败只记录日志,不返回错误
// 这是因为缓存失效失败不应影响主业务流程
err := invalidator.InvalidateToken(context.Background(), tt.account)
require.Error(t, err)
require.Equal(t, expectedErr, err)
require.NoError(t, err)
})
}
}
@@ -252,9 +292,12 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
{ID: 4, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
}
// 新行为Gemini 和 Antigravity 会同时删除基于 project_id 和 account_id 的键
expectedKeys := []string{
"gemini:gemini-proj",
"gemini:account:1",
"ag:ag-proj",
"ag:account:2",
"openai:account:3",
"claude:account:4",
}
@@ -266,3 +309,239 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
require.Equal(t, expectedKeys, cache.deletedKeys)
}
// ========== GetCredentialAsInt64 测试 ==========
func TestAccount_GetCredentialAsInt64(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
key string
expected int64
}{
{
name: "int64_value",
credentials: map[string]any{"_token_version": int64(1737654321000)},
key: "_token_version",
expected: 1737654321000,
},
{
name: "float64_value",
credentials: map[string]any{"_token_version": float64(1737654321000)},
key: "_token_version",
expected: 1737654321000,
},
{
name: "int_value",
credentials: map[string]any{"_token_version": 12345},
key: "_token_version",
expected: 12345,
},
{
name: "string_value",
credentials: map[string]any{"_token_version": "1737654321000"},
key: "_token_version",
expected: 1737654321000,
},
{
name: "string_with_spaces",
credentials: map[string]any{"_token_version": " 1737654321000 "},
key: "_token_version",
expected: 1737654321000,
},
{
name: "nil_credentials",
credentials: nil,
key: "_token_version",
expected: 0,
},
{
name: "missing_key",
credentials: map[string]any{"other_key": 123},
key: "_token_version",
expected: 0,
},
{
name: "nil_value",
credentials: map[string]any{"_token_version": nil},
key: "_token_version",
expected: 0,
},
{
name: "invalid_string",
credentials: map[string]any{"_token_version": "not_a_number"},
key: "_token_version",
expected: 0,
},
{
name: "empty_string",
credentials: map[string]any{"_token_version": ""},
key: "_token_version",
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{Credentials: tt.credentials}
result := account.GetCredentialAsInt64(tt.key)
require.Equal(t, tt.expected, result)
})
}
}
func TestAccount_GetCredentialAsInt64_NilAccount(t *testing.T) {
var account *Account
result := account.GetCredentialAsInt64("_token_version")
require.Equal(t, int64(0), result)
}
// ========== CheckTokenVersion 测试 ==========
func TestCheckTokenVersion(t *testing.T) {
tests := []struct {
name string
account *Account
latestAccount *Account
repoErr error
expectedStale bool
}{
{
name: "nil_account",
account: nil,
latestAccount: nil,
expectedStale: false,
},
{
name: "no_version_in_account_but_db_has_version",
account: &Account{
ID: 1,
Credentials: map[string]any{},
},
latestAccount: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
expectedStale: true, // 当前 account 无版本但 DB 有,说明已被异步刷新,当前已过时
},
{
name: "both_no_version",
account: &Account{
ID: 1,
Credentials: map[string]any{},
},
latestAccount: &Account{
ID: 1,
Credentials: map[string]any{},
},
expectedStale: false, // 两边都没有版本号,说明从未被异步刷新过,允许缓存
},
{
name: "same_version",
account: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
latestAccount: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
expectedStale: false,
},
{
name: "current_version_newer",
account: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(200)},
},
latestAccount: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
expectedStale: false,
},
{
name: "current_version_older_stale",
account: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
latestAccount: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(200)},
},
expectedStale: true, // 当前版本过时
},
{
name: "repo_error",
account: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
latestAccount: nil,
repoErr: errors.New("db error"),
expectedStale: false, // 查询失败,默认允许缓存
},
{
name: "repo_returns_nil",
account: &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
},
latestAccount: nil,
repoErr: nil,
expectedStale: false, // 查询返回 nil默认允许缓存
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 由于 CheckTokenVersion 接受 AccountRepository 接口,而创建完整的 mock 很繁琐
// 这里我们直接测试函数的核心逻辑来验证行为
if tt.name == "nil_account" {
_, isStale := CheckTokenVersion(context.Background(), nil, nil)
require.Equal(t, tt.expectedStale, isStale)
return
}
// 模拟 CheckTokenVersion 的核心逻辑
account := tt.account
currentVersion := account.GetCredentialAsInt64("_token_version")
// 模拟 repo 查询
latestAccount := tt.latestAccount
if tt.repoErr != nil || latestAccount == nil {
require.Equal(t, tt.expectedStale, false)
return
}
latestVersion := latestAccount.GetCredentialAsInt64("_token_version")
// 情况1: 当前 account 没有版本号,但 DB 中已有版本号
if currentVersion == 0 && latestVersion > 0 {
require.Equal(t, tt.expectedStale, true)
return
}
// 情况2: 两边都没有版本号
if currentVersion == 0 && latestVersion == 0 {
require.Equal(t, tt.expectedStale, false)
return
}
// 情况3: 比较版本号
isStale := latestVersion > currentVersion
require.Equal(t, tt.expectedStale, isStale)
})
}
}
func TestCheckTokenVersion_NilRepo(t *testing.T) {
account := &Account{
ID: 1,
Credentials: map[string]any{"_token_version": int64(100)},
}
_, isStale := CheckTokenVersion(context.Background(), account, nil)
require.False(t, isStale) // nil repo默认允许缓存
}

View File

@@ -169,6 +169,10 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
// 如果有新凭证,先更新(即使有错误也要保存 token
if newCredentials != nil {
// 记录刷新版本时间戳,用于解决缓存一致性问题
// TokenProvider 写入缓存前会检查此版本,如果版本已更新则跳过写入
newCredentials["_token_version"] = time.Now().UnixMilli()
account.Credentials = newCredentials
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
return fmt.Errorf("failed to save credentials: %w", saveErr)
@@ -233,7 +237,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
}
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
// 这些错误通常表示凭证已失效,需要用户重新授权
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
// 注意missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
func isNonRetryableRefreshError(err error) bool {
if err == nil {
return false

View File

@@ -0,0 +1,506 @@
package service
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"log/slog"
"time"
"github.com/pquerna/otp/totp"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (
ErrTotpNotEnabled = infraerrors.BadRequest("TOTP_NOT_ENABLED", "totp feature is not enabled")
ErrTotpAlreadyEnabled = infraerrors.BadRequest("TOTP_ALREADY_ENABLED", "totp is already enabled for this account")
ErrTotpNotSetup = infraerrors.BadRequest("TOTP_NOT_SETUP", "totp is not set up for this account")
ErrTotpInvalidCode = infraerrors.BadRequest("TOTP_INVALID_CODE", "invalid totp code")
ErrTotpSetupExpired = infraerrors.BadRequest("TOTP_SETUP_EXPIRED", "totp setup session expired")
ErrTotpTooManyAttempts = infraerrors.TooManyRequests("TOTP_TOO_MANY_ATTEMPTS", "too many verification attempts, please try again later")
ErrVerifyCodeRequired = infraerrors.BadRequest("VERIFY_CODE_REQUIRED", "email verification code is required")
ErrPasswordRequired = infraerrors.BadRequest("PASSWORD_REQUIRED", "password is required")
)
// TotpCache defines cache operations for TOTP service
type TotpCache interface {
// Setup session methods
GetSetupSession(ctx context.Context, userID int64) (*TotpSetupSession, error)
SetSetupSession(ctx context.Context, userID int64, session *TotpSetupSession, ttl time.Duration) error
DeleteSetupSession(ctx context.Context, userID int64) error
// Login session methods (for 2FA login flow)
GetLoginSession(ctx context.Context, tempToken string) (*TotpLoginSession, error)
SetLoginSession(ctx context.Context, tempToken string, session *TotpLoginSession, ttl time.Duration) error
DeleteLoginSession(ctx context.Context, tempToken string) error
// Rate limiting
IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error)
GetVerifyAttempts(ctx context.Context, userID int64) (int, error)
ClearVerifyAttempts(ctx context.Context, userID int64) error
}
// SecretEncryptor defines encryption operations for TOTP secrets
type SecretEncryptor interface {
Encrypt(plaintext string) (string, error)
Decrypt(ciphertext string) (string, error)
}
// TotpSetupSession represents a TOTP setup session
type TotpSetupSession struct {
Secret string // Plain text TOTP secret (not encrypted yet)
SetupToken string // Random token to verify setup request
CreatedAt time.Time
}
// TotpLoginSession represents a pending 2FA login session
type TotpLoginSession struct {
UserID int64
Email string
TokenExpiry time.Time
}
// TotpStatus represents the TOTP status for a user
type TotpStatus struct {
Enabled bool `json:"enabled"`
EnabledAt *time.Time `json:"enabled_at,omitempty"`
FeatureEnabled bool `json:"feature_enabled"`
}
// TotpSetupResponse represents the response for initiating TOTP setup
type TotpSetupResponse struct {
Secret string `json:"secret"`
QRCodeURL string `json:"qr_code_url"`
SetupToken string `json:"setup_token"`
Countdown int `json:"countdown"` // seconds until setup expires
}
const (
totpSetupTTL = 5 * time.Minute
totpLoginTTL = 5 * time.Minute
totpAttemptsTTL = 15 * time.Minute
maxTotpAttempts = 5
totpIssuer = "Sub2API"
)
// TotpService handles TOTP operations
type TotpService struct {
userRepo UserRepository
encryptor SecretEncryptor
cache TotpCache
settingService *SettingService
emailService *EmailService
emailQueueService *EmailQueueService
}
// NewTotpService creates a new TOTP service
func NewTotpService(
userRepo UserRepository,
encryptor SecretEncryptor,
cache TotpCache,
settingService *SettingService,
emailService *EmailService,
emailQueueService *EmailQueueService,
) *TotpService {
return &TotpService{
userRepo: userRepo,
encryptor: encryptor,
cache: cache,
settingService: settingService,
emailService: emailService,
emailQueueService: emailQueueService,
}
}
// GetStatus returns the TOTP status for a user
func (s *TotpService) GetStatus(ctx context.Context, userID int64) (*TotpStatus, error) {
featureEnabled := s.settingService.IsTotpEnabled(ctx)
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
return &TotpStatus{
Enabled: user.TotpEnabled,
EnabledAt: user.TotpEnabledAt,
FeatureEnabled: featureEnabled,
}, nil
}
// InitiateSetup starts the TOTP setup process
// If email verification is enabled, emailCode is required; otherwise password is required
func (s *TotpService) InitiateSetup(ctx context.Context, userID int64, emailCode, password string) (*TotpSetupResponse, error) {
// Check if TOTP feature is enabled globally
if !s.settingService.IsTotpEnabled(ctx) {
return nil, ErrTotpNotEnabled
}
// Get user and check if TOTP is already enabled
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
if user.TotpEnabled {
return nil, ErrTotpAlreadyEnabled
}
// Verify identity based on email verification setting
if s.settingService.IsEmailVerifyEnabled(ctx) {
// Email verification enabled - verify email code
if emailCode == "" {
return nil, ErrVerifyCodeRequired
}
if err := s.emailService.VerifyCode(ctx, user.Email, emailCode); err != nil {
return nil, err
}
} else {
// Email verification disabled - verify password
if password == "" {
return nil, ErrPasswordRequired
}
if !user.CheckPassword(password) {
return nil, ErrPasswordIncorrect
}
}
// Generate a new TOTP key
key, err := totp.Generate(totp.GenerateOpts{
Issuer: totpIssuer,
AccountName: user.Email,
})
if err != nil {
return nil, fmt.Errorf("generate totp key: %w", err)
}
// Generate a random setup token
setupToken, err := generateRandomToken(32)
if err != nil {
return nil, fmt.Errorf("generate setup token: %w", err)
}
// Store the setup session in cache
session := &TotpSetupSession{
Secret: key.Secret(),
SetupToken: setupToken,
CreatedAt: time.Now(),
}
if err := s.cache.SetSetupSession(ctx, userID, session, totpSetupTTL); err != nil {
return nil, fmt.Errorf("store setup session: %w", err)
}
return &TotpSetupResponse{
Secret: key.Secret(),
QRCodeURL: key.URL(),
SetupToken: setupToken,
Countdown: int(totpSetupTTL.Seconds()),
}, nil
}
// CompleteSetup completes the TOTP setup by verifying the code
func (s *TotpService) CompleteSetup(ctx context.Context, userID int64, totpCode, setupToken string) error {
// Check if TOTP feature is enabled globally
if !s.settingService.IsTotpEnabled(ctx) {
return ErrTotpNotEnabled
}
// Get the setup session
session, err := s.cache.GetSetupSession(ctx, userID)
if err != nil {
return ErrTotpSetupExpired
}
if session == nil {
return ErrTotpSetupExpired
}
// Verify the setup token (constant-time comparison)
if subtle.ConstantTimeCompare([]byte(session.SetupToken), []byte(setupToken)) != 1 {
return ErrTotpSetupExpired
}
// Verify the TOTP code
if !totp.Validate(totpCode, session.Secret) {
return ErrTotpInvalidCode
}
setupSecretPrefix := "N/A"
if len(session.Secret) >= 4 {
setupSecretPrefix = session.Secret[:4]
}
slog.Debug("totp_complete_setup_before_encrypt",
"user_id", userID,
"secret_len", len(session.Secret),
"secret_prefix", setupSecretPrefix)
// Encrypt the secret
encryptedSecret, err := s.encryptor.Encrypt(session.Secret)
if err != nil {
return fmt.Errorf("encrypt totp secret: %w", err)
}
slog.Debug("totp_complete_setup_encrypted",
"user_id", userID,
"encrypted_len", len(encryptedSecret))
// Verify encryption by decrypting
decrypted, decErr := s.encryptor.Decrypt(encryptedSecret)
if decErr != nil {
slog.Debug("totp_complete_setup_verify_failed",
"user_id", userID,
"error", decErr)
} else {
decryptedPrefix := "N/A"
if len(decrypted) >= 4 {
decryptedPrefix = decrypted[:4]
}
slog.Debug("totp_complete_setup_verified",
"user_id", userID,
"original_len", len(session.Secret),
"decrypted_len", len(decrypted),
"match", session.Secret == decrypted,
"decrypted_prefix", decryptedPrefix)
}
// Update user with encrypted TOTP secret
if err := s.userRepo.UpdateTotpSecret(ctx, userID, &encryptedSecret); err != nil {
return fmt.Errorf("update totp secret: %w", err)
}
// Enable TOTP for the user
if err := s.userRepo.EnableTotp(ctx, userID); err != nil {
return fmt.Errorf("enable totp: %w", err)
}
// Clean up the setup session
_ = s.cache.DeleteSetupSession(ctx, userID)
return nil
}
// Disable disables TOTP for a user
// If email verification is enabled, emailCode is required; otherwise password is required
func (s *TotpService) Disable(ctx context.Context, userID int64, emailCode, password string) error {
// Get user
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
if !user.TotpEnabled {
return ErrTotpNotSetup
}
// Verify identity based on email verification setting
if s.settingService.IsEmailVerifyEnabled(ctx) {
// Email verification enabled - verify email code
if emailCode == "" {
return ErrVerifyCodeRequired
}
if err := s.emailService.VerifyCode(ctx, user.Email, emailCode); err != nil {
return err
}
} else {
// Email verification disabled - verify password
if password == "" {
return ErrPasswordRequired
}
if !user.CheckPassword(password) {
return ErrPasswordIncorrect
}
}
// Disable TOTP
if err := s.userRepo.DisableTotp(ctx, userID); err != nil {
return fmt.Errorf("disable totp: %w", err)
}
return nil
}
// VerifyCode verifies a TOTP code for a user
func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string) error {
slog.Debug("totp_verify_code_called",
"user_id", userID,
"code_len", len(code))
// Check rate limiting
attempts, err := s.cache.GetVerifyAttempts(ctx, userID)
if err == nil && attempts >= maxTotpAttempts {
return ErrTotpTooManyAttempts
}
// Get user
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
slog.Debug("totp_verify_get_user_failed",
"user_id", userID,
"error", err)
return infraerrors.InternalServer("TOTP_VERIFY_ERROR", "failed to verify totp code")
}
if !user.TotpEnabled || user.TotpSecretEncrypted == nil {
slog.Debug("totp_verify_not_setup",
"user_id", userID,
"enabled", user.TotpEnabled,
"has_secret", user.TotpSecretEncrypted != nil)
return ErrTotpNotSetup
}
slog.Debug("totp_verify_encrypted_secret",
"user_id", userID,
"encrypted_len", len(*user.TotpSecretEncrypted))
// Decrypt the secret
secret, err := s.encryptor.Decrypt(*user.TotpSecretEncrypted)
if err != nil {
slog.Debug("totp_verify_decrypt_failed",
"user_id", userID,
"error", err)
return infraerrors.InternalServer("TOTP_VERIFY_ERROR", "failed to verify totp code")
}
secretPrefix := "N/A"
if len(secret) >= 4 {
secretPrefix = secret[:4]
}
slog.Debug("totp_verify_decrypted",
"user_id", userID,
"secret_len", len(secret),
"secret_prefix", secretPrefix)
// Verify the code
valid := totp.Validate(code, secret)
slog.Debug("totp_verify_result",
"user_id", userID,
"valid", valid,
"secret_len", len(secret),
"secret_prefix", secretPrefix,
"server_time", time.Now().UTC().Format(time.RFC3339))
if !valid {
// Increment failed attempts
_, _ = s.cache.IncrementVerifyAttempts(ctx, userID)
return ErrTotpInvalidCode
}
// Clear attempt counter on success
_ = s.cache.ClearVerifyAttempts(ctx, userID)
return nil
}
// CreateLoginSession creates a temporary login session for 2FA
func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) {
// Generate a random temp token
tempToken, err := generateRandomToken(32)
if err != nil {
return "", fmt.Errorf("generate temp token: %w", err)
}
session := &TotpLoginSession{
UserID: userID,
Email: email,
TokenExpiry: time.Now().Add(totpLoginTTL),
}
if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil {
return "", fmt.Errorf("store login session: %w", err)
}
return tempToken, nil
}
// GetLoginSession retrieves a login session
func (s *TotpService) GetLoginSession(ctx context.Context, tempToken string) (*TotpLoginSession, error) {
return s.cache.GetLoginSession(ctx, tempToken)
}
// DeleteLoginSession deletes a login session
func (s *TotpService) DeleteLoginSession(ctx context.Context, tempToken string) error {
return s.cache.DeleteLoginSession(ctx, tempToken)
}
// IsTotpEnabledForUser checks if TOTP is enabled for a specific user
func (s *TotpService) IsTotpEnabledForUser(ctx context.Context, userID int64) (bool, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return false, fmt.Errorf("get user: %w", err)
}
return user.TotpEnabled, nil
}
// MaskEmail masks an email address for display
func MaskEmail(email string) string {
if len(email) < 3 {
return "***"
}
atIdx := -1
for i, c := range email {
if c == '@' {
atIdx = i
break
}
}
if atIdx == -1 || atIdx < 1 {
return email[:1] + "***"
}
localPart := email[:atIdx]
domain := email[atIdx:]
if len(localPart) <= 2 {
return localPart[:1] + "***" + domain
}
return localPart[:1] + "***" + localPart[len(localPart)-1:] + domain
}
// generateRandomToken generates a random hex-encoded token
func generateRandomToken(byteLength int) (string, error) {
b := make([]byte, byteLength)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
// VerificationMethod represents the method required for TOTP operations
type VerificationMethod struct {
Method string `json:"method"` // "email" or "password"
}
// GetVerificationMethod returns the verification method for TOTP operations
func (s *TotpService) GetVerificationMethod(ctx context.Context) *VerificationMethod {
if s.settingService.IsEmailVerifyEnabled(ctx) {
return &VerificationMethod{Method: "email"}
}
return &VerificationMethod{Method: "password"}
}
// SendVerifyCode sends an email verification code for TOTP operations
func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error {
// Check if email verification is enabled
if !s.settingService.IsEmailVerifyEnabled(ctx) {
return infraerrors.BadRequest("EMAIL_VERIFY_NOT_ENABLED", "email verification is not enabled")
}
// Get user email
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
// Get site name for email
siteName := s.settingService.GetSiteName(ctx)
// Send verification code via queue
return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName)
}

View File

@@ -21,6 +21,11 @@ type User struct {
CreatedAt time.Time
UpdatedAt time.Time
// TOTP 双因素认证字段
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
TotpEnabled bool // 是否启用 TOTP
TotpEnabledAt *time.Time // TOTP 启用时间
APIKeys []APIKey
Subscriptions []UserSubscription
}

View File

@@ -38,6 +38,11 @@ type UserRepository interface {
UpdateConcurrency(ctx context.Context, id int64, amount int) error
ExistsByEmail(ctx context.Context, email string) (bool, error)
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
// TOTP 相关方法
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
EnableTotp(ctx context.Context, userID int64) error
DisableTotp(ctx context.Context, userID int64) error
}
// UpdateProfileRequest 更新用户资料请求

View File

@@ -18,7 +18,7 @@ type UserSubscriptionRepository interface {
ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error

View File

@@ -72,6 +72,13 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
return svc
}
// ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService.
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository) *SubscriptionExpiryService {
svc := NewSubscriptionExpiryService(userSubRepo, time.Minute)
svc.Start()
return svc
}
// ProvideTimingWheelService creates and starts TimingWheelService
func ProvideTimingWheelService() (*TimingWheelService, error) {
svc, err := NewTimingWheelService()
@@ -256,6 +263,7 @@ var ProviderSet = wire.NewSet(
ProvideUpdateService,
ProvideTokenRefreshService,
ProvideAccountExpiryService,
ProvideSubscriptionExpiryService,
ProvideTimingWheelService,
ProvideDashboardAggregationService,
ProvideUsageCleanupService,
@@ -263,4 +271,5 @@ var ProviderSet = wire.NewSet(
NewAntigravityQuotaFetcher,
NewUserAttributeService,
NewUsageCache,
NewTotpService,
)

View File

@@ -46,7 +46,7 @@ func ValidateURLFormat(raw string, allowInsecureHTTP bool) (string, error) {
}
}
return trimmed, nil
return strings.TrimRight(trimmed, "/"), nil
}
func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) {

View File

@@ -21,4 +21,31 @@ func TestValidateURLFormat(t *testing.T) {
if _, err := ValidateURLFormat("https://example.com:bad", true); err == nil {
t.Fatalf("expected invalid port to fail")
}
// 验证末尾斜杠被移除
normalized, err := ValidateURLFormat("https://example.com/", false)
if err != nil {
t.Fatalf("expected trailing slash url to pass, got %v", err)
}
if normalized != "https://example.com" {
t.Fatalf("expected trailing slash to be removed, got %s", normalized)
}
// 验证多个末尾斜杠被移除
normalized, err = ValidateURLFormat("https://example.com///", false)
if err != nil {
t.Fatalf("expected multiple trailing slashes to pass, got %v", err)
}
if normalized != "https://example.com" {
t.Fatalf("expected all trailing slashes to be removed, got %s", normalized)
}
// 验证带路径的 URL 末尾斜杠被移除
normalized, err = ValidateURLFormat("https://example.com/api/v1/", false)
if err != nil {
t.Fatalf("expected trailing slash url with path to pass, got %v", err)
}
if normalized != "https://example.com/api/v1" {
t.Fatalf("expected trailing slash to be removed from path, got %s", normalized)
}
}

View File

@@ -0,0 +1,12 @@
-- 为 users 表添加 TOTP 双因素认证字段
ALTER TABLE users
ADD COLUMN IF NOT EXISTS totp_secret_encrypted TEXT DEFAULT NULL,
ADD COLUMN IF NOT EXISTS totp_enabled BOOLEAN NOT NULL DEFAULT FALSE,
ADD COLUMN IF NOT EXISTS totp_enabled_at TIMESTAMPTZ DEFAULT NULL;
COMMENT ON COLUMN users.totp_secret_encrypted IS 'AES-256-GCM 加密的 TOTP 密钥';
COMMENT ON COLUMN users.totp_enabled IS '是否启用 TOTP 双因素认证';
COMMENT ON COLUMN users.totp_enabled_at IS 'TOTP 启用时间';
-- 创建索引以支持快速查询启用 2FA 的用户
CREATE INDEX IF NOT EXISTS idx_users_totp_enabled ON users(totp_enabled) WHERE deleted_at IS NULL AND totp_enabled = true;

View File

@@ -61,6 +61,18 @@ ADMIN_PASSWORD=
JWT_SECRET=
JWT_EXPIRE_HOUR=24
# -----------------------------------------------------------------------------
# TOTP (2FA) Configuration
# TOTP双因素认证配置
# -----------------------------------------------------------------------------
# IMPORTANT: Set a fixed encryption key for TOTP secrets. If left empty, a
# random key will be generated on each startup, causing all existing TOTP
# configurations to become invalid (users won't be able to login with 2FA).
# Generate a secure key: openssl rand -hex 32
# 重要:设置固定的 TOTP 加密密钥。如果留空,每次启动将生成随机密钥,
# 导致现有的 TOTP 配置失效(用户无法使用双因素认证登录)。
TOTP_ENCRYPTION_KEY=
# -----------------------------------------------------------------------------
# Configuration File (Optional)
# -----------------------------------------------------------------------------

View File

@@ -403,6 +403,21 @@ jwt:
# 令牌过期时间(小时,最大 24
expire_hour: 24
# =============================================================================
# TOTP (2FA) Configuration
# TOTP 双因素认证配置
# =============================================================================
totp:
# IMPORTANT: Set a fixed encryption key for TOTP secrets.
# 重要:设置固定的 TOTP 加密密钥。
# If left empty, a random key will be generated on each startup, causing all
# existing TOTP configurations to become invalid (users won't be able to
# login with 2FA).
# 如果留空,每次启动将生成随机密钥,导致现有的 TOTP 配置失效(用户无法使用
# 双因素认证登录)。
# Generate with / 生成命令: openssl rand -hex 32
encryption_key: ""
# =============================================================================
# LinuxDo Connect OAuth Login (SSO)
# LinuxDo Connect OAuth 登录(用于 Sub2API 用户登录)

View File

@@ -79,6 +79,16 @@ services:
- JWT_SECRET=${JWT_SECRET:-}
- JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24}
# =======================================================================
# TOTP (2FA) Configuration
# =======================================================================
# IMPORTANT: Set a fixed encryption key for TOTP secrets. If left empty,
# a random key will be generated on each startup, causing all existing
# TOTP configurations to become invalid (users won't be able to login
# with 2FA).
# Generate a secure key: openssl rand -hex 32
- TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-}
# =======================================================================
# Timezone Configuration
# This affects ALL time operations in the application:

View File

@@ -15,6 +15,7 @@
"driver.js": "^1.4.0",
"file-saver": "^2.0.5",
"pinia": "^2.1.7",
"qrcode": "^1.5.4",
"vue": "^3.4.0",
"vue-chartjs": "^5.3.0",
"vue-i18n": "^9.14.5",
@@ -25,6 +26,7 @@
"@types/file-saver": "^2.0.7",
"@types/mdx": "^2.0.13",
"@types/node": "^20.10.5",
"@types/qrcode": "^1.5.6",
"@typescript-eslint/eslint-plugin": "^7.18.0",
"@typescript-eslint/parser": "^7.18.0",
"@vitejs/plugin-vue": "^5.2.3",
@@ -1680,6 +1682,16 @@
"integrity": "sha512-dISoDXWWQwUquiKsyZ4Ng+HX2KsPL7LyHKHQwgGFEA3IaKac4Obd+h2a/a6waisAoepJlBcx9paWqjA8/HVjCw==",
"license": "MIT"
},
"node_modules/@types/qrcode": {
"version": "1.5.6",
"resolved": "https://registry.npmmirror.com/@types/qrcode/-/qrcode-1.5.6.tgz",
"integrity": "sha512-te7NQcV2BOvdj2b1hCAHzAoMNuj65kNBMz0KBaxM6c3VGBOhU0dURQKOtH8CFNI/dsKkwlv32p26qYQTWoB5bw==",
"dev": true,
"license": "MIT",
"dependencies": {
"@types/node": "*"
}
},
"node_modules/@types/web-bluetooth": {
"version": "0.0.20",
"resolved": "https://registry.npmjs.org/@types/web-bluetooth/-/web-bluetooth-0.0.20.tgz",
@@ -2354,7 +2366,6 @@
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz",
"integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">=8"
@@ -2364,7 +2375,6 @@
"version": "4.3.0",
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz",
"integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==",
"dev": true,
"license": "MIT",
"dependencies": {
"color-convert": "^2.0.1"
@@ -2646,6 +2656,15 @@
"node": ">=6"
}
},
"node_modules/camelcase": {
"version": "5.3.1",
"resolved": "https://registry.npmmirror.com/camelcase/-/camelcase-5.3.1.tgz",
"integrity": "sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg==",
"license": "MIT",
"engines": {
"node": ">=6"
}
},
"node_modules/camelcase-css": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/camelcase-css/-/camelcase-css-2.0.1.tgz",
@@ -2784,6 +2803,51 @@
"node": ">= 6"
}
},
"node_modules/cliui": {
"version": "6.0.0",
"resolved": "https://registry.npmmirror.com/cliui/-/cliui-6.0.0.tgz",
"integrity": "sha512-t6wbgtoCXvAzst7QgXxJYqPt0usEfbgQdftEPbLL/cvv6HPE5VgvqCuAIDR0NgU52ds6rFwqrgakNLrHEjCbrQ==",
"license": "ISC",
"dependencies": {
"string-width": "^4.2.0",
"strip-ansi": "^6.0.0",
"wrap-ansi": "^6.2.0"
}
},
"node_modules/cliui/node_modules/emoji-regex": {
"version": "8.0.0",
"resolved": "https://registry.npmmirror.com/emoji-regex/-/emoji-regex-8.0.0.tgz",
"integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==",
"license": "MIT"
},
"node_modules/cliui/node_modules/string-width": {
"version": "4.2.3",
"resolved": "https://registry.npmmirror.com/string-width/-/string-width-4.2.3.tgz",
"integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==",
"license": "MIT",
"dependencies": {
"emoji-regex": "^8.0.0",
"is-fullwidth-code-point": "^3.0.0",
"strip-ansi": "^6.0.1"
},
"engines": {
"node": ">=8"
}
},
"node_modules/cliui/node_modules/wrap-ansi": {
"version": "6.2.0",
"resolved": "https://registry.npmmirror.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz",
"integrity": "sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA==",
"license": "MIT",
"dependencies": {
"ansi-styles": "^4.0.0",
"string-width": "^4.1.0",
"strip-ansi": "^6.0.0"
},
"engines": {
"node": ">=8"
}
},
"node_modules/clsx": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz",
@@ -2806,7 +2870,6 @@
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz",
"integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"color-name": "~1.1.4"
@@ -2819,7 +2882,6 @@
"version": "1.1.4",
"resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz",
"integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==",
"dev": true,
"license": "MIT"
},
"node_modules/combined-stream": {
@@ -2989,6 +3051,15 @@
}
}
},
"node_modules/decamelize": {
"version": "1.2.0",
"resolved": "https://registry.npmmirror.com/decamelize/-/decamelize-1.2.0.tgz",
"integrity": "sha512-z2S+W9X73hAUUki+N+9Za2lBlun89zigOyGrsax+KUQ6wKW4ZoWpEYBkGhQjwAjjDCkWxhY0VKEhk8wzY7F5cA==",
"license": "MIT",
"engines": {
"node": ">=0.10.0"
}
},
"node_modules/decimal.js": {
"version": "10.6.0",
"resolved": "https://registry.npmjs.org/decimal.js/-/decimal.js-10.6.0.tgz",
@@ -3029,6 +3100,12 @@
"dev": true,
"license": "Apache-2.0"
},
"node_modules/dijkstrajs": {
"version": "1.0.3",
"resolved": "https://registry.npmmirror.com/dijkstrajs/-/dijkstrajs-1.0.3.tgz",
"integrity": "sha512-qiSlmBq9+BCdCA/L46dw8Uy93mloxsPSbwnm5yrKn2vMPiy8KyAskTF6zuV/j5BMsmOGZDPs7KjU+mjb670kfA==",
"license": "MIT"
},
"node_modules/dir-glob": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz",
@@ -3759,6 +3836,15 @@
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/get-caller-file": {
"version": "2.0.5",
"resolved": "https://registry.npmmirror.com/get-caller-file/-/get-caller-file-2.0.5.tgz",
"integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==",
"license": "ISC",
"engines": {
"node": "6.* || 8.* || >= 10.*"
}
},
"node_modules/get-intrinsic": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz",
@@ -4156,7 +4242,6 @@
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz",
"integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">=8"
@@ -4883,6 +4968,15 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/p-try": {
"version": "2.2.0",
"resolved": "https://registry.npmmirror.com/p-try/-/p-try-2.2.0.tgz",
"integrity": "sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==",
"license": "MIT",
"engines": {
"node": ">=6"
}
},
"node_modules/package-json-from-dist": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/package-json-from-dist/-/package-json-from-dist-1.0.1.tgz",
@@ -4957,7 +5051,6 @@
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz",
"integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">=8"
@@ -5093,6 +5186,15 @@
"node": ">= 6"
}
},
"node_modules/pngjs": {
"version": "5.0.0",
"resolved": "https://registry.npmmirror.com/pngjs/-/pngjs-5.0.0.tgz",
"integrity": "sha512-40QW5YalBNfQo5yRYmiw7Yz6TKKVr3h6970B2YE+3fQpsWcrbj1PzJgxeJ19DRQjhMbKPIuMY8rFaXc8moolVw==",
"license": "MIT",
"engines": {
"node": ">=10.13.0"
}
},
"node_modules/polished": {
"version": "4.3.1",
"resolved": "https://registry.npmjs.org/polished/-/polished-4.3.1.tgz",
@@ -5313,6 +5415,23 @@
"node": ">=6"
}
},
"node_modules/qrcode": {
"version": "1.5.4",
"resolved": "https://registry.npmmirror.com/qrcode/-/qrcode-1.5.4.tgz",
"integrity": "sha512-1ca71Zgiu6ORjHqFBDpnSMTR2ReToX4l1Au1VFLyVeBTFavzQnv5JxMFr3ukHVKpSrSA2MCk0lNJSykjUfz7Zg==",
"license": "MIT",
"dependencies": {
"dijkstrajs": "^1.0.1",
"pngjs": "^5.0.0",
"yargs": "^15.3.1"
},
"bin": {
"qrcode": "bin/qrcode"
},
"engines": {
"node": ">=10.13.0"
}
},
"node_modules/querystringify": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/querystringify/-/querystringify-2.2.0.tgz",
@@ -5370,6 +5489,21 @@
"node": ">=8.10.0"
}
},
"node_modules/require-directory": {
"version": "2.1.1",
"resolved": "https://registry.npmmirror.com/require-directory/-/require-directory-2.1.1.tgz",
"integrity": "sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==",
"license": "MIT",
"engines": {
"node": ">=0.10.0"
}
},
"node_modules/require-main-filename": {
"version": "2.0.0",
"resolved": "https://registry.npmmirror.com/require-main-filename/-/require-main-filename-2.0.0.tgz",
"integrity": "sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg==",
"license": "ISC"
},
"node_modules/requires-port": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/requires-port/-/requires-port-1.0.0.tgz",
@@ -5543,6 +5677,12 @@
"node": ">=10"
}
},
"node_modules/set-blocking": {
"version": "2.0.0",
"resolved": "https://registry.npmmirror.com/set-blocking/-/set-blocking-2.0.0.tgz",
"integrity": "sha512-KiKBS8AnWGEyLzofFfmvKwpdPzqiy16LvQfK3yv/fVH7Bj13/wl3JSR1J+rfgRE9q7xUJK4qvgS8raSOeLUehw==",
"license": "ISC"
},
"node_modules/shebang-command": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz",
@@ -5714,7 +5854,6 @@
"version": "6.0.1",
"resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz",
"integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==",
"dev": true,
"license": "MIT",
"dependencies": {
"ansi-regex": "^5.0.1"
@@ -6715,6 +6854,12 @@
"node": ">= 8"
}
},
"node_modules/which-module": {
"version": "2.0.1",
"resolved": "https://registry.npmmirror.com/which-module/-/which-module-2.0.1.tgz",
"integrity": "sha512-iBdZ57RDvnOR9AGBhML2vFZf7h8vmBjhoaZqODJBFWHVtKkDmKuHai3cx5PgVMrX5YDNp27AofYbAwctSS+vhQ==",
"license": "ISC"
},
"node_modules/why-is-node-running": {
"version": "2.3.0",
"resolved": "https://registry.npmjs.org/why-is-node-running/-/why-is-node-running-2.3.0.tgz",
@@ -6928,6 +7073,12 @@
"dev": true,
"license": "MIT"
},
"node_modules/y18n": {
"version": "4.0.3",
"resolved": "https://registry.npmmirror.com/y18n/-/y18n-4.0.3.tgz",
"integrity": "sha512-JKhqTOwSrqNA1NY5lSztJ1GrBiUodLMmIZuLiDaMRJ+itFd+ABVE8XBjOvIWL+rSqNDC74LCSFmlb/U4UZ4hJQ==",
"license": "ISC"
},
"node_modules/yaml": {
"version": "1.10.2",
"resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.2.tgz",
@@ -6937,6 +7088,113 @@
"node": ">= 6"
}
},
"node_modules/yargs": {
"version": "15.4.1",
"resolved": "https://registry.npmmirror.com/yargs/-/yargs-15.4.1.tgz",
"integrity": "sha512-aePbxDmcYW++PaqBsJ+HYUFwCdv4LVvdnhBy78E57PIor8/OVvhMrADFFEDh8DHDFRv/O9i3lPhsENjO7QX0+A==",
"license": "MIT",
"dependencies": {
"cliui": "^6.0.0",
"decamelize": "^1.2.0",
"find-up": "^4.1.0",
"get-caller-file": "^2.0.1",
"require-directory": "^2.1.1",
"require-main-filename": "^2.0.0",
"set-blocking": "^2.0.0",
"string-width": "^4.2.0",
"which-module": "^2.0.0",
"y18n": "^4.0.0",
"yargs-parser": "^18.1.2"
},
"engines": {
"node": ">=8"
}
},
"node_modules/yargs-parser": {
"version": "18.1.3",
"resolved": "https://registry.npmmirror.com/yargs-parser/-/yargs-parser-18.1.3.tgz",
"integrity": "sha512-o50j0JeToy/4K6OZcaQmW6lyXXKhq7csREXcDwk2omFPJEwUNOVtJKvmDr9EI1fAJZUyZcRF7kxGBWmRXudrCQ==",
"license": "ISC",
"dependencies": {
"camelcase": "^5.0.0",
"decamelize": "^1.2.0"
},
"engines": {
"node": ">=6"
}
},
"node_modules/yargs/node_modules/emoji-regex": {
"version": "8.0.0",
"resolved": "https://registry.npmmirror.com/emoji-regex/-/emoji-regex-8.0.0.tgz",
"integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==",
"license": "MIT"
},
"node_modules/yargs/node_modules/find-up": {
"version": "4.1.0",
"resolved": "https://registry.npmmirror.com/find-up/-/find-up-4.1.0.tgz",
"integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==",
"license": "MIT",
"dependencies": {
"locate-path": "^5.0.0",
"path-exists": "^4.0.0"
},
"engines": {
"node": ">=8"
}
},
"node_modules/yargs/node_modules/locate-path": {
"version": "5.0.0",
"resolved": "https://registry.npmmirror.com/locate-path/-/locate-path-5.0.0.tgz",
"integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==",
"license": "MIT",
"dependencies": {
"p-locate": "^4.1.0"
},
"engines": {
"node": ">=8"
}
},
"node_modules/yargs/node_modules/p-limit": {
"version": "2.3.0",
"resolved": "https://registry.npmmirror.com/p-limit/-/p-limit-2.3.0.tgz",
"integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==",
"license": "MIT",
"dependencies": {
"p-try": "^2.0.0"
},
"engines": {
"node": ">=6"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/yargs/node_modules/p-locate": {
"version": "4.1.0",
"resolved": "https://registry.npmmirror.com/p-locate/-/p-locate-4.1.0.tgz",
"integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==",
"license": "MIT",
"dependencies": {
"p-limit": "^2.2.0"
},
"engines": {
"node": ">=8"
}
},
"node_modules/yargs/node_modules/string-width": {
"version": "4.2.3",
"resolved": "https://registry.npmmirror.com/string-width/-/string-width-4.2.3.tgz",
"integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==",
"license": "MIT",
"dependencies": {
"emoji-regex": "^8.0.0",
"is-fullwidth-code-point": "^3.0.0",
"strip-ansi": "^6.0.1"
},
"engines": {
"node": ">=8"
}
},
"node_modules/yocto-queue": {
"version": "0.1.0",
"resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz",

View File

@@ -22,6 +22,7 @@
"driver.js": "^1.4.0",
"file-saver": "^2.0.5",
"pinia": "^2.1.7",
"qrcode": "^1.5.4",
"vue": "^3.4.0",
"vue-chartjs": "^5.3.0",
"vue-i18n": "^9.14.5",
@@ -32,6 +33,7 @@
"@types/file-saver": "^2.0.7",
"@types/mdx": "^2.0.13",
"@types/node": "^20.10.5",
"@types/qrcode": "^1.5.6",
"@typescript-eslint/eslint-plugin": "^7.18.0",
"@typescript-eslint/parser": "^7.18.0",
"@vitejs/plugin-vue": "^5.2.3",

175
frontend/pnpm-lock.yaml generated
View File

@@ -29,6 +29,9 @@ importers:
pinia:
specifier: ^2.1.7
version: 2.3.1(typescript@5.6.3)(vue@3.5.26(typescript@5.6.3))
qrcode:
specifier: ^1.5.4
version: 1.5.4
vue:
specifier: ^3.4.0
version: 3.5.26(typescript@5.6.3)
@@ -54,6 +57,9 @@ importers:
'@types/node':
specifier: ^20.10.5
version: 20.19.27
'@types/qrcode':
specifier: ^1.5.6
version: 1.5.6
'@typescript-eslint/eslint-plugin':
specifier: ^7.18.0
version: 7.18.0(@typescript-eslint/parser@7.18.0(eslint@8.57.1)(typescript@5.6.3))(eslint@8.57.1)(typescript@5.6.3)
@@ -1239,56 +1245,67 @@ packages:
resolution: {integrity: sha512-EHMUcDwhtdRGlXZsGSIuXSYwD5kOT9NVnx9sqzYiwAc91wfYOE1g1djOEDseZJKKqtHAHGwnGPQu3kytmfaXLQ==}
cpu: [arm]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-arm-musleabihf@4.54.0':
resolution: {integrity: sha512-+pBrqEjaakN2ySv5RVrj/qLytYhPKEUwk+e3SFU5jTLHIcAtqh2rLrd/OkbNuHJpsBgxsD8ccJt5ga/SeG0JmA==}
cpu: [arm]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-arm64-gnu@4.54.0':
resolution: {integrity: sha512-NSqc7rE9wuUaRBsBp5ckQ5CVz5aIRKCwsoa6WMF7G01sX3/qHUw/z4pv+D+ahL1EIKy6Enpcnz1RY8pf7bjwng==}
cpu: [arm64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-arm64-musl@4.54.0':
resolution: {integrity: sha512-gr5vDbg3Bakga5kbdpqx81m2n9IX8M6gIMlQQIXiLTNeQW6CucvuInJ91EuCJ/JYvc+rcLLsDFcfAD1K7fMofg==}
cpu: [arm64]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-loong64-gnu@4.54.0':
resolution: {integrity: sha512-gsrtB1NA3ZYj2vq0Rzkylo9ylCtW/PhpLEivlgWe0bpgtX5+9j9EZa0wtZiCjgu6zmSeZWyI/e2YRX1URozpIw==}
cpu: [loong64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-ppc64-gnu@4.54.0':
resolution: {integrity: sha512-y3qNOfTBStmFNq+t4s7Tmc9hW2ENtPg8FeUD/VShI7rKxNW7O4fFeaYbMsd3tpFlIg1Q8IapFgy7Q9i2BqeBvA==}
cpu: [ppc64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-riscv64-gnu@4.54.0':
resolution: {integrity: sha512-89sepv7h2lIVPsFma8iwmccN7Yjjtgz0Rj/Ou6fEqg3HDhpCa+Et+YSufy27i6b0Wav69Qv4WBNl3Rs6pwhebQ==}
cpu: [riscv64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-riscv64-musl@4.54.0':
resolution: {integrity: sha512-ZcU77ieh0M2Q8Ur7D5X7KvK+UxbXeDHwiOt/CPSBTI1fBmeDMivW0dPkdqkT4rOgDjrDDBUed9x4EgraIKoR2A==}
cpu: [riscv64]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-s390x-gnu@4.54.0':
resolution: {integrity: sha512-2AdWy5RdDF5+4YfG/YesGDDtbyJlC9LHmL6rZw6FurBJ5n4vFGupsOBGfwMRjBYH7qRQowT8D/U4LoSvVwOhSQ==}
cpu: [s390x]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-x64-gnu@4.54.0':
resolution: {integrity: sha512-WGt5J8Ij/rvyqpFexxk3ffKqqbLf9AqrTBbWDk7ApGUzaIs6V+s2s84kAxklFwmMF/vBNGrVdYgbblCOFFezMQ==}
cpu: [x64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-x64-musl@4.54.0':
resolution: {integrity: sha512-JzQmb38ATzHjxlPHuTH6tE7ojnMKM2kYNzt44LO/jJi8BpceEC8QuXYA908n8r3CNuG/B3BV8VR3Hi1rYtmPiw==}
cpu: [x64]
os: [linux]
libc: [musl]
'@rollup/rollup-openharmony-arm64@4.54.0':
resolution: {integrity: sha512-huT3fd0iC7jigGh7n3q/+lfPcXxBi+om/Rs3yiFxjvSxbSB6aohDFXbWvlspaqjeOh+hx7DDHS+5Es5qRkWkZg==}
@@ -1479,6 +1496,9 @@ packages:
'@types/parse-json@4.0.2':
resolution: {integrity: sha512-dISoDXWWQwUquiKsyZ4Ng+HX2KsPL7LyHKHQwgGFEA3IaKac4Obd+h2a/a6waisAoepJlBcx9paWqjA8/HVjCw==}
'@types/qrcode@1.5.6':
resolution: {integrity: sha512-te7NQcV2BOvdj2b1hCAHzAoMNuj65kNBMz0KBaxM6c3VGBOhU0dURQKOtH8CFNI/dsKkwlv32p26qYQTWoB5bw==}
'@types/react@19.2.7':
resolution: {integrity: sha512-MWtvHrGZLFttgeEj28VXHxpmwYbor/ATPYbBfSFZEIRK0ecCFLl2Qo55z52Hss+UV9CRN7trSeq1zbgx7YDWWg==}
@@ -1832,6 +1852,10 @@ packages:
resolution: {integrity: sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA==}
engines: {node: '>= 6'}
camelcase@5.3.1:
resolution: {integrity: sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg==}
engines: {node: '>=6'}
caniuse-lite@1.0.30001761:
resolution: {integrity: sha512-JF9ptu1vP2coz98+5051jZ4PwQgd2ni8A+gYSN7EA7dPKIMf0pDlSUxhdmVOaV3/fYK5uWBkgSXJaRLr4+3A6g==}
@@ -1895,6 +1919,9 @@ packages:
classnames@2.5.1:
resolution: {integrity: sha512-saHYOzhIQs6wy2sVxTM6bUDsQO4F50V9RQ22qBpEdCW+I+/Wmke2HOl6lS6dTpdxVhb88/I6+Hs+438c3lfUow==}
cliui@6.0.0:
resolution: {integrity: sha512-t6wbgtoCXvAzst7QgXxJYqPt0usEfbgQdftEPbLL/cvv6HPE5VgvqCuAIDR0NgU52ds6rFwqrgakNLrHEjCbrQ==}
clsx@1.2.1:
resolution: {integrity: sha512-EcR6r5a8bj6pu3ycsa/E/cKVGuTgZJZdsyUYHOksG/UHIiKfjxzRxYJpyVBwYaQeOvghal9fcc4PidlgzugAQg==}
engines: {node: '>=6'}
@@ -2164,6 +2191,10 @@ packages:
supports-color:
optional: true
decamelize@1.2.0:
resolution: {integrity: sha512-z2S+W9X73hAUUki+N+9Za2lBlun89zigOyGrsax+KUQ6wKW4ZoWpEYBkGhQjwAjjDCkWxhY0VKEhk8wzY7F5cA==}
engines: {node: '>=0.10.0'}
decimal.js@10.6.0:
resolution: {integrity: sha512-YpgQiITW3JXGntzdUmyUR1V812Hn8T1YVXhCu+wO3OpS4eU9l4YdD3qjyiKdV6mvV29zapkMeD390UVEf2lkUg==}
@@ -2198,6 +2229,9 @@ packages:
didyoumean@1.2.2:
resolution: {integrity: sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==}
dijkstrajs@1.0.3:
resolution: {integrity: sha512-qiSlmBq9+BCdCA/L46dw8Uy93mloxsPSbwnm5yrKn2vMPiy8KyAskTF6zuV/j5BMsmOGZDPs7KjU+mjb670kfA==}
dir-glob@3.0.1:
resolution: {integrity: sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==}
engines: {node: '>=8'}
@@ -2424,6 +2458,10 @@ packages:
find-root@1.1.0:
resolution: {integrity: sha512-NKfW6bec6GfKc0SGx1e07QZY9PE99u0Bft/0rzSD5k3sO/vwkVUpDUKVm5Gpp5Ue3YfShPFTX2070tDs5kB9Ng==}
find-up@4.1.0:
resolution: {integrity: sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==}
engines: {node: '>=8'}
find-up@5.0.0:
resolution: {integrity: sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==}
engines: {node: '>=10'}
@@ -2488,6 +2526,10 @@ packages:
function-bind@1.1.2:
resolution: {integrity: sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==}
get-caller-file@2.0.5:
resolution: {integrity: sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==}
engines: {node: 6.* || 8.* || >= 10.*}
get-east-asian-width@1.4.0:
resolution: {integrity: sha512-QZjmEOC+IT1uk6Rx0sX22V6uHWVwbdbxf1faPqJ1QhLdGgsRGCZoyaQBm/piRdJy/D2um6hM1UP7ZEeQ4EkP+Q==}
engines: {node: '>=18'}
@@ -2856,6 +2898,10 @@ packages:
lit@3.3.2:
resolution: {integrity: sha512-NF9zbsP79l4ao2SNrH3NkfmFgN/hBYSQo90saIVI1o5GpjAdCPVstVzO1MrLOakHoEhYkrtRjPK6Ob521aoYWQ==}
locate-path@5.0.0:
resolution: {integrity: sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==}
engines: {node: '>=8'}
locate-path@6.0.0:
resolution: {integrity: sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==}
engines: {node: '>=10'}
@@ -3239,14 +3285,26 @@ packages:
resolution: {integrity: sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==}
engines: {node: '>= 0.8.0'}
p-limit@2.3.0:
resolution: {integrity: sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==}
engines: {node: '>=6'}
p-limit@3.1.0:
resolution: {integrity: sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==}
engines: {node: '>=10'}
p-locate@4.1.0:
resolution: {integrity: sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==}
engines: {node: '>=8'}
p-locate@5.0.0:
resolution: {integrity: sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==}
engines: {node: '>=10'}
p-try@2.2.0:
resolution: {integrity: sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==}
engines: {node: '>=6'}
package-json-from-dist@1.0.1:
resolution: {integrity: sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==}
@@ -3341,6 +3399,10 @@ packages:
pkg-types@1.3.1:
resolution: {integrity: sha512-/Jm5M4RvtBFVkKWRu2BLUTNP8/M2a+UwuAX+ae4770q1qVGtfjG+WTCupoZixokjmHiry8uI+dlY8KXYV5HVVQ==}
pngjs@5.0.0:
resolution: {integrity: sha512-40QW5YalBNfQo5yRYmiw7Yz6TKKVr3h6970B2YE+3fQpsWcrbj1PzJgxeJ19DRQjhMbKPIuMY8rFaXc8moolVw==}
engines: {node: '>=10.13.0'}
points-on-curve@0.2.0:
resolution: {integrity: sha512-0mYKnYYe9ZcqMCWhUjItv/oHjvgEsfKvnUTg8sAtnHr3GVy7rGkXCb6d5cSyqrWqL4k81b9CPg3urd+T7aop3A==}
@@ -3421,6 +3483,11 @@ packages:
resolution: {integrity: sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==}
engines: {node: '>=6'}
qrcode@1.5.4:
resolution: {integrity: sha512-1ca71Zgiu6ORjHqFBDpnSMTR2ReToX4l1Au1VFLyVeBTFavzQnv5JxMFr3ukHVKpSrSA2MCk0lNJSykjUfz7Zg==}
engines: {node: '>=10.13.0'}
hasBin: true
query-string@9.3.1:
resolution: {integrity: sha512-5fBfMOcDi5SA9qj5jZhWAcTtDfKF5WFdd2uD9nVNlbxVv1baq65aALy6qofpNEGELHvisjjasxQp7BlM9gvMzw==}
engines: {node: '>=18'}
@@ -3664,6 +3731,13 @@ packages:
remark-stringify@11.0.0:
resolution: {integrity: sha512-1OSmLd3awB/t8qdoEOMazZkNsfVTeY4fTsgzcQFdXNq8ToTN4ZGwrMnlda4K6smTFKD+GRV6O48i6Z4iKgPPpw==}
require-directory@2.1.1:
resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==}
engines: {node: '>=0.10.0'}
require-main-filename@2.0.0:
resolution: {integrity: sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg==}
requires-port@1.0.0:
resolution: {integrity: sha512-KigOCHcocU3XODJxsu8i/j8T9tzT4adHiecwORRQ0ZZFcp7ahwXuRU1m+yuO90C5ZUyGeGfocHDI14M3L3yDAQ==}
@@ -3739,6 +3813,9 @@ packages:
engines: {node: '>=10'}
hasBin: true
set-blocking@2.0.0:
resolution: {integrity: sha512-KiKBS8AnWGEyLzofFfmvKwpdPzqiy16LvQfK3yv/fVH7Bj13/wl3JSR1J+rfgRE9q7xUJK4qvgS8raSOeLUehw==}
set-value@2.0.1:
resolution: {integrity: sha512-JxHc1weCN68wRY0fhCoXpyK55m/XPHafOmK4UWD7m2CI14GMcFypt4w/0+NV5f/ZMby2F6S2wwA7fgynh9gWSw==}
engines: {node: '>=0.10.0'}
@@ -4263,6 +4340,9 @@ packages:
resolution: {integrity: sha512-De72GdQZzNTUBBChsXueQUnPKDkg/5A5zp7pFDuQAj5UFoENpiACU0wlCvzpAGnTkj++ihpKwKyYewn/XNUbKw==}
engines: {node: '>=18'}
which-module@2.0.1:
resolution: {integrity: sha512-iBdZ57RDvnOR9AGBhML2vFZf7h8vmBjhoaZqODJBFWHVtKkDmKuHai3cx5PgVMrX5YDNp27AofYbAwctSS+vhQ==}
which@2.0.2:
resolution: {integrity: sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==}
engines: {node: '>= 8'}
@@ -4285,6 +4365,10 @@ packages:
resolution: {integrity: sha512-OELeY0Q61OXpdUfTp+oweA/vtLVg5VDOXh+3he3PNzLGG/y0oylSOC1xRVj0+l4vQ3tj/bB1HVHv1ocXkQceFA==}
engines: {node: '>=0.8'}
wrap-ansi@6.2.0:
resolution: {integrity: sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA==}
engines: {node: '>=8'}
wrap-ansi@7.0.0:
resolution: {integrity: sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==}
engines: {node: '>=10'}
@@ -4324,10 +4408,21 @@ packages:
xmlchars@2.2.0:
resolution: {integrity: sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==}
y18n@4.0.3:
resolution: {integrity: sha512-JKhqTOwSrqNA1NY5lSztJ1GrBiUodLMmIZuLiDaMRJ+itFd+ABVE8XBjOvIWL+rSqNDC74LCSFmlb/U4UZ4hJQ==}
yaml@1.10.2:
resolution: {integrity: sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg==}
engines: {node: '>= 6'}
yargs-parser@18.1.3:
resolution: {integrity: sha512-o50j0JeToy/4K6OZcaQmW6lyXXKhq7csREXcDwk2omFPJEwUNOVtJKvmDr9EI1fAJZUyZcRF7kxGBWmRXudrCQ==}
engines: {node: '>=6'}
yargs@15.4.1:
resolution: {integrity: sha512-aePbxDmcYW++PaqBsJ+HYUFwCdv4LVvdnhBy78E57PIor8/OVvhMrADFFEDh8DHDFRv/O9i3lPhsENjO7QX0+A==}
engines: {node: '>=8'}
yocto-queue@0.1.0:
resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==}
engines: {node: '>=10'}
@@ -5838,6 +5933,10 @@ snapshots:
'@types/parse-json@4.0.2': {}
'@types/qrcode@1.5.6':
dependencies:
'@types/node': 20.19.27
'@types/react@19.2.7':
dependencies:
csstype: 3.2.3
@@ -6321,6 +6420,8 @@ snapshots:
camelcase-css@2.0.1: {}
camelcase@5.3.1: {}
caniuse-lite@1.0.30001761: {}
ccount@2.0.1: {}
@@ -6395,6 +6496,12 @@ snapshots:
classnames@2.5.1: {}
cliui@6.0.0:
dependencies:
string-width: 4.2.3
strip-ansi: 6.0.1
wrap-ansi: 6.2.0
clsx@1.2.1: {}
clsx@2.1.1: {}
@@ -6668,6 +6775,8 @@ snapshots:
dependencies:
ms: 2.1.3
decamelize@1.2.0: {}
decimal.js@10.6.0: {}
decode-named-character-reference@1.2.0:
@@ -6694,6 +6803,8 @@ snapshots:
didyoumean@1.2.2: {}
dijkstrajs@1.0.3: {}
dir-glob@3.0.1:
dependencies:
path-type: 4.0.0
@@ -6978,6 +7089,11 @@ snapshots:
find-root@1.1.0: {}
find-up@4.1.0:
dependencies:
locate-path: 5.0.0
path-exists: 4.0.0
find-up@5.0.0:
dependencies:
locate-path: 6.0.0
@@ -7029,6 +7145,8 @@ snapshots:
function-bind@1.1.2: {}
get-caller-file@2.0.5: {}
get-east-asian-width@1.4.0: {}
get-intrinsic@1.3.0:
@@ -7521,6 +7639,10 @@ snapshots:
lit-element: 4.2.2
lit-html: 3.3.2
locate-path@5.0.0:
dependencies:
p-locate: 4.1.0
locate-path@6.0.0:
dependencies:
p-locate: 5.0.0
@@ -8194,14 +8316,24 @@ snapshots:
type-check: 0.4.0
word-wrap: 1.2.5
p-limit@2.3.0:
dependencies:
p-try: 2.2.0
p-limit@3.1.0:
dependencies:
yocto-queue: 0.1.0
p-locate@4.1.0:
dependencies:
p-limit: 2.3.0
p-locate@5.0.0:
dependencies:
p-limit: 3.1.0
p-try@2.2.0: {}
package-json-from-dist@1.0.1: {}
package-manager-detector@1.6.0: {}
@@ -8284,6 +8416,8 @@ snapshots:
mlly: 1.8.0
pathe: 2.0.3
pngjs@5.0.0: {}
points-on-curve@0.2.0: {}
points-on-path@0.2.1:
@@ -8352,6 +8486,12 @@ snapshots:
punycode@2.3.1: {}
qrcode@1.5.4:
dependencies:
dijkstrajs: 1.0.3
pngjs: 5.0.0
yargs: 15.4.1
query-string@9.3.1:
dependencies:
decode-uri-component: 0.4.1
@@ -8703,6 +8843,10 @@ snapshots:
mdast-util-to-markdown: 2.1.2
unified: 11.0.5
require-directory@2.1.1: {}
require-main-filename@2.0.0: {}
requires-port@1.0.0: {}
reselect@5.1.1: {}
@@ -8788,6 +8932,8 @@ snapshots:
semver@7.7.3: {}
set-blocking@2.0.0: {}
set-value@2.0.1:
dependencies:
extend-shallow: 2.0.1
@@ -9298,6 +9444,8 @@ snapshots:
tr46: 5.1.1
webidl-conversions: 7.0.0
which-module@2.0.1: {}
which@2.0.2:
dependencies:
isexe: 2.0.0
@@ -9313,6 +9461,12 @@ snapshots:
word@0.3.0: {}
wrap-ansi@6.2.0:
dependencies:
ansi-styles: 4.3.0
string-width: 4.2.3
strip-ansi: 6.0.1
wrap-ansi@7.0.0:
dependencies:
ansi-styles: 4.3.0
@@ -9345,8 +9499,29 @@ snapshots:
xmlchars@2.2.0: {}
y18n@4.0.3: {}
yaml@1.10.2: {}
yargs-parser@18.1.3:
dependencies:
camelcase: 5.3.1
decamelize: 1.2.0
yargs@15.4.1:
dependencies:
cliui: 6.0.0
decamelize: 1.2.0
find-up: 4.1.0
get-caller-file: 2.0.5
require-directory: 2.1.1
require-main-filename: 2.0.0
set-blocking: 2.0.0
string-width: 4.2.3
which-module: 2.0.1
y18n: 4.0.3
yargs-parser: 18.1.3
yocto-queue@0.1.0: {}
zustand@3.7.2(react@19.2.3):

View File

@@ -13,6 +13,9 @@ export interface SystemSettings {
registration_enabled: boolean
email_verify_enabled: boolean
promo_code_enabled: boolean
password_reset_enabled: boolean
totp_enabled: boolean // TOTP 双因素认证
totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置
// Default settings
default_balance: number
default_concurrency: number
@@ -66,6 +69,8 @@ export interface UpdateSettingsRequest {
registration_enabled?: boolean
email_verify_enabled?: boolean
promo_code_enabled?: boolean
password_reset_enabled?: boolean
totp_enabled?: boolean // TOTP 双因素认证
default_balance?: number
default_concurrency?: number
site_name?: string

View File

@@ -17,7 +17,7 @@ import type {
* List all subscriptions with pagination
* @param page - Page number (default: 1)
* @param pageSize - Items per page (default: 20)
* @param filters - Optional filters (status, user_id, group_id)
* @param filters - Optional filters (status, user_id, group_id, sort_by, sort_order)
* @returns Paginated list of subscriptions
*/
export async function list(
@@ -27,6 +27,8 @@ export async function list(
status?: 'active' | 'expired' | 'revoked'
user_id?: number
group_id?: number
sort_by?: string
sort_order?: 'asc' | 'desc'
},
options?: {
signal?: AbortSignal

View File

@@ -11,9 +11,23 @@ import type {
CurrentUserResponse,
SendVerifyCodeRequest,
SendVerifyCodeResponse,
PublicSettings
PublicSettings,
TotpLoginResponse,
TotpLogin2FARequest
} from '@/types'
/**
* Login response type - can be either full auth or 2FA required
*/
export type LoginResponse = AuthResponse | TotpLoginResponse
/**
* Type guard to check if login response requires 2FA
*/
export function isTotp2FARequired(response: LoginResponse): response is TotpLoginResponse {
return 'requires_2fa' in response && response.requires_2fa === true
}
/**
* Store authentication token in localStorage
*/
@@ -38,11 +52,28 @@ export function clearAuthToken(): void {
/**
* User login
* @param credentials - Username and password
* @param credentials - Email and password
* @returns Authentication response with token and user data, or 2FA required response
*/
export async function login(credentials: LoginRequest): Promise<LoginResponse> {
const { data } = await apiClient.post<LoginResponse>('/auth/login', credentials)
// Only store token if 2FA is not required
if (!isTotp2FARequired(data)) {
setAuthToken(data.access_token)
localStorage.setItem('auth_user', JSON.stringify(data.user))
}
return data
}
/**
* Complete login with 2FA code
* @param request - Temp token and TOTP code
* @returns Authentication response with token and user data
*/
export async function login(credentials: LoginRequest): Promise<AuthResponse> {
const { data } = await apiClient.post<AuthResponse>('/auth/login', credentials)
export async function login2FA(request: TotpLogin2FARequest): Promise<AuthResponse> {
const { data } = await apiClient.post<AuthResponse>('/auth/login/2fa', request)
// Store token and user data
setAuthToken(data.access_token)
@@ -133,8 +164,61 @@ export async function validatePromoCode(code: string): Promise<ValidatePromoCode
return data
}
/**
* Forgot password request
*/
export interface ForgotPasswordRequest {
email: string
turnstile_token?: string
}
/**
* Forgot password response
*/
export interface ForgotPasswordResponse {
message: string
}
/**
* Request password reset link
* @param request - Email and optional Turnstile token
* @returns Response with message
*/
export async function forgotPassword(request: ForgotPasswordRequest): Promise<ForgotPasswordResponse> {
const { data } = await apiClient.post<ForgotPasswordResponse>('/auth/forgot-password', request)
return data
}
/**
* Reset password request
*/
export interface ResetPasswordRequest {
email: string
token: string
new_password: string
}
/**
* Reset password response
*/
export interface ResetPasswordResponse {
message: string
}
/**
* Reset password with token
* @param request - Email, token, and new password
* @returns Response with message
*/
export async function resetPassword(request: ResetPasswordRequest): Promise<ResetPasswordResponse> {
const { data } = await apiClient.post<ResetPasswordResponse>('/auth/reset-password', request)
return data
}
export const authAPI = {
login,
login2FA,
isTotp2FARequired,
register,
getCurrentUser,
logout,
@@ -144,7 +228,9 @@ export const authAPI = {
clearAuthToken,
getPublicSettings,
sendVerifyCode,
validatePromoCode
validatePromoCode,
forgotPassword,
resetPassword
}
export default authAPI

Some files were not shown because too many files have changed in this diff Show More