Compare commits

..

87 Commits

Author SHA1 Message Date
shaw
cadca752c4 修复SSE流式响应中usage数据被覆盖的问题 2026-01-28 18:36:21 +08:00
Wesley Liddick
edf215e6fd Merge pull request #409 from DuckyProject/feat/purchase-subscription-iframe
feat(purchase): 增加购买订阅 iframe 页面与配置
2026-01-28 17:28:47 +08:00
shaw
e12dd079fd 修复调度器空缓存导致的竞态条件bug
当新分组创建后立即绑定账号时,调度器会错误地将空快照视为有效缓存命中,
导致返回没有可调度的账号。现在空快照会触发数据库回退查询。
2026-01-28 17:26:32 +08:00
ducky
04a509d45e feat(purchase): 增加购买订阅 iframe 页面与配置
- 新增 /purchase 页面(iframe + 新窗口兜底)

- 管理员系统设置可配置开关与URL

- 非 simple mode 才在侧边栏展示入口
2026-01-28 13:54:32 +08:00
Wesley Liddick
269a659200 Merge pull request #406 from geminiwen/main
fix(openai-oauth): 改进错误处理和代理支持
2026-01-28 13:53:44 +08:00
Wesley Liddick
2c31bf46b5 Merge pull request #401 from slovx2/heihuzi_main
feat(gemini): 为 Gemini 原生平台添加图片计费支持
2026-01-28 13:51:14 +08:00
Gemini Wen
8f6639f825 fix(response): add nil check for c.Request in error logging
Prevents panic when ErrorFrom is called in test contexts where
gin.CreateTestContext doesn't set up an HTTP request.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 19:26:44 +08:00
Gemini Wen
fc17d9d7df chore: bump version to 0.1.61 and fix tests
- Update VERSION from 0.1.46 to 0.1.61
- Remove ForceHTTP2 tests for OpenAI OAuth client (ForceHTTP2 was removed)
- Update createOpenAIReqClient test to use new single-arg signature

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 19:22:45 +08:00
Gemini Wen
ab092e88a8 fix(openai-oauth): 改进错误处理和代理支持
- 使用 ApplicationError 返回详细错误信息到前端
- 添加 User-Agent: codex-cli/0.91.0
- 移除 ForceHTTP2 以兼容 HTTP 代理
- 修复代理获取失败时静默忽略的问题
- 500 错误时记录完整错误日志

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 19:13:01 +08:00
shaw
56a1e29cdd fix(gateway): 修复 SSE 流式响应 usage 统计错误
message_delta 应完全覆盖 message_start 的 usage 数据,
而非仅在值为 0 时才更新。
2026-01-27 09:16:34 +08:00
song
0059a232a6 feat(gemini): 为 Gemini 原生平台添加图片计费支持
对齐 Antigravity 平台的图片计费逻辑:
- 添加 extractImageSize() 方法提取图片尺寸
- Forward() 和 ForwardNative() 返回 ImageCount/ImageSize
- 支持分组自定义图片价格和倍率
2026-01-26 20:51:40 +08:00
shaw
45676fdc8d fix(ci): 转义 Telegram 消息中的 Markdown 特殊字符
修复发布通知发送失败的问题,原因是 tag message 中包含未闭合的
Markdown 格式标记(如 user_id 中的 _ 被解析为斜体开始)导致
Telegram API 返回解析错误。

添加 sed 命令转义 _、*、` 和 [ 字符,避免被 Telegram Markdown
解析器错误处理。
2026-01-26 11:07:08 +08:00
Wesley Liddick
e32c5f534f Merge pull request #386 from IanShaw027/fix/openai-usage-limit-reset-time
fix(ratelimit): 修复 OpenAI usage_limit_reached 错误的重置时间解析
2026-01-26 10:22:42 +08:00
shaw
426d691c95 fix(urlvalidator): 移除 ValidateURLFormat 返回值的末尾斜杠
修复 API Key 账号 base_url 末尾带斜杠时导致的双斜杠问题
2026-01-26 10:21:41 +08:00
shaw
e9a4c8ab97 docs: 修改演示站点域名 2026-01-26 10:04:44 +08:00
ianshaw
a55cfebd09 fix(ratelimit): 修复 OpenAI usage_limit_reached 错误的重置时间解析
- 问题:OpenAI 的 usage_limit_reached 错误(需 37 小时重置)被错误地设置为 5 分钟
- 原因:handle429 只检查 Anthropic 响应头,没有解析 OpenAI 响应体中的 resets_in_seconds
- 修复:新增 parseOpenAIRateLimitResetTime 函数解析 OpenAI 响应体
- 影响:避免调度器不断尝试已达配额上限的账户
2026-01-26 09:57:44 +08:00
Wesley Liddick
34cc02f8c7 Merge pull request #393 from IanShaw027/fix/gemini-thought-signature-preserve
fix(gemini): 修复 thoughtSignature 跨账号验证错误
2026-01-26 09:23:46 +08:00
Wesley Liddick
624d9fddb7 Merge pull request #391 from geminiwen/main
fix(subscription): 修复订阅调整逻辑,已过期订阅从当前时间计算
2026-01-26 09:23:29 +08:00
Wesley Liddick
47fbe43324 Merge pull request #385 from DDZS987/fix/oauth-token-refresh-missing-project-id-retry
fix(oauth): 修复 OAuth 令牌刷新时 missing_project_id 误报问题
2026-01-26 09:22:48 +08:00
shaw
1245f07a2d feat(auth): 实现 TOTP 双因素认证功能
新增功能:
- 支持 Google Authenticator 等应用进行 TOTP 二次验证
- 用户可在个人设置中启用/禁用 2FA
- 登录时支持 TOTP 验证流程
- 管理后台可全局开关 TOTP 功能

安全增强:
- TOTP 密钥使用 AES-256-GCM 加密存储
- 添加 TOTP_ENCRYPTION_KEY 配置项,必须手动配置才能启用功能
- 防止服务重启导致加密密钥变更使用户无法登录
- 验证失败次数限制,防止暴力破解

配置说明:
- Docker 部署:在 .env 中设置 TOTP_ENCRYPTION_KEY
- 非 Docker 部署:在 config.yaml 中设置 totp.encryption_key
- 生成密钥命令:openssl rand -hex 32
2026-01-26 09:19:53 +08:00
ianshaw
839975b0cf feat(gemini): 支持 Gemini CLI 粘性会话与跨账号 thoughtSignature 清理
## 问题背景

1. Gemini CLI 没有明确的会话标识(如 Claude Code 的 metadata.user_id)
2. thoughtSignature 与具体上游账号强绑定,跨账号使用会导致 400 错误
3. 粘性会话切换账号或 cache 丢失时,旧签名会导致请求失败

## 解决方案

### 1. Gemini CLI 会话标识提取

- 从 `x-gemini-api-privileged-user-id` header 和请求体中的 tmp 目录哈希生成会话标识
- 组合策略:SHA256(privileged-user-id + ":" + tmp_dir_hash)
- 正则提取:`/\.gemini/tmp/([A-Fa-f0-9]{64})`

### 2. 跨账号 thoughtSignature 清理

实现三种场景的智能清理:

1. **Cache 命中 + 账号切换**
   - 粘性会话绑定的账号与当前选择的账号不同时清理

2. **同一请求内 failover 切换**
   - 通过 sessionBoundAccountID 跟踪,检测重试时的账号切换

3. **Gemini CLI + Cache 未命中 + 含签名**
   - 预防性清理,避免 cache 丢失后首次转发就 400
   - 仅对 Gemini CLI 请求且请求体包含 thoughtSignature 时触发

## 修改内容

### backend/internal/handler/gemini_v1beta_handler.go
- 添加 `extractGeminiCLISessionHash` 函数提取 Gemini CLI 会话标识
- 添加 `isGeminiCLIRequest` 函数识别 Gemini CLI 请求
- 实现账号切换检测与 thoughtSignature 清理逻辑
- 添加 `geminiCLITmpDirRegex` 正则表达式

### backend/internal/service/gateway_service.go
- 添加 `GetCachedSessionAccountID` 方法查询粘性会话绑定的账号 ID

### backend/internal/service/gemini_native_signature_cleaner.go (新增)
- 实现 `CleanGeminiNativeThoughtSignatures` 函数
- 递归清理 JSON 中的所有 thoughtSignature 字段
- 支持任意 JSON 顶层类型(object/array)

### backend/internal/handler/gemini_cli_session_test.go (新增)
- 测试 Gemini CLI 会话哈希提取逻辑
- 测试 tmp 目录正则匹配
- 覆盖有/无 privileged-user-id 的场景

## 影响范围

- 修复 Gemini CLI 多轮对话时账号切换导致的 400 错误
- 提高粘性会话的稳定性和容错能力
- 不影响其他客户端(Claude Code 等)的会话标识生成

## 测试

- 单元测试:go test -tags=unit ./internal/handler -run TestExtractGeminiCLISessionHash
- 单元测试:go test -tags=unit ./internal/handler -run TestGeminiCLITmpDirRegex
- 编译验证:go build ./cmd/server
2026-01-26 04:40:38 +08:00
ianshaw
8c1233393f fix(antigravity): 修复 Gemini 模型 thoughtSignature 被错误覆盖的问题
## 问题描述

在使用 Gemini 模型(gemini-3-flash-preview)时,出现 400 错误:
"Unable to submit request because Thought signature is not valid"

## 根本原因

在 `request_transformer.go` 的 `buildParts()` 函数中:
- 对于 `tool_use` 和 `thinking` 块,当 `allowDummyThought=true`(Gemini 模型)时
- 代码会无条件将客户端传入的真实 `thoughtSignature` 覆盖成 dummy 值
- 导致 Gemini API 验证签名失败(签名与上下文不匹配)

## 修复方案

修改 signature 处理逻辑:
1. **优先透传真实 signature**:如果客户端提供了有效的 signature,保留它
2. **缺失时才使用 dummy**:只有在 signature 缺失且是 Gemini 模型时,才使用 dummy signature
3. **Claude 模型特殊处理**:将 dummy signature 视为缺失,避免透传到需要真实签名的链路

## 修改内容

### request_transformer.go
- `thinking` 块(第 367-386 行):优先透传真实 signature
- `tool_use` 块(第 411-418 行):优先透传真实 signature

### request_transformer_test.go
- 修改测试用例名称,反映新的行为
- 新增测试用例验证"缺失时才使用 dummy"的逻辑

## 影响范围

- 修复 Gemini 模型在多轮对话中使用 tool_use 时的签名验证错误
- 不影响 Claude 模型的现有行为
- 提高跨账号切换时的稳定性

相关问题:#[issue_number]
2026-01-26 03:06:45 +08:00
Gemini Wen
9cdb0568cc fix(subscription): 修复订阅调整逻辑,已过期订阅从当前时间计算
- 已过期订阅延长时,从当前时间开始增加天数
- 已过期订阅不允许负向调整(缩短)
- 未过期订阅保持原逻辑,从原过期时间增减

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-25 18:12:15 +08:00
shaw
74e05b83ea fix(ratelimit): 修复 OpenAI 账号限流倒计时计算错误
- 解析 x-codex-* 响应头获取正确的重置时间
- 7d 限制用尽时使用 codex_7d_reset_after_seconds
- 提取 Normalize() 方法统一窗口规范化逻辑
2026-01-25 13:32:08 +08:00
Ubuntu
4ded9e7d49 fix(oauth): 为初始 OAuth 授权添加 LoadCodeAssist 重试机制
问题:
- 初始授权时 LoadCodeAssist 没有重试机制,失败后直接跳过
- 导致账号创建时就可能缺失 project_id
- 之后每次刷新都因为 missing_project_id 报错

修复:
- 统一使用 loadProjectIDWithRetry 方法(最多4次尝试)
- 初始授权和token刷新使用相同的重试策略
- 保留原注释说明部分账户可能没有 project_id

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-24 23:41:36 +08:00
Ubuntu
716272a1e2 fix(oauth): 彻底修复 project_id 丢失问题
根本原因:
- BuildAccountCredentials 只在 project_id 非空时才添加该字段
- LoadCodeAssist 失败时返回空字符串 → 新 credentials 不包含 project_id 键
- 普通合并逻辑只保留新 credentials 中不存在的键,无法覆盖空值

解决方案:
1. 在合并后特殊处理 project_id:如果新值为空但旧值非空,保留旧值
2. LoadCodeAssist 失败不再返回错误,只记录警告
3. Token 刷新成功(access_token 已更新)就不应标记账户为 error

改进效果:
- 即使 LoadCodeAssist 连续失败,已有的 project_id 也不会丢失
- 避免因临时网络问题将账户误标记为不可用
- 允许在下次刷新时自动重试获取 project_id

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-24 23:04:48 +08:00
shaw
9cc8352593 feat(auth): 密码重置邮件队列化与限流优化
- 邮件发送改为异步队列处理,避免并发导致发送失败
- 新增 Email 维度限流(30秒冷却期),防止邮件轰炸
- Token 验证使用常量时间比较,防止时序攻击
- 重构代码消除冗余,提取公共验证逻辑
2026-01-24 22:55:28 +08:00
shaw
43a1031e38 fix(test): 修复订阅相关测试失败问题
1. 使用未来日期(2099年)作为测试订阅的过期时间,避免
   normalizeSubscriptionStatus 将测试数据标记为过期
2. 修复 List 方法调用参数不足的问题(新增 sortBy/sortOrder 参数)
2026-01-24 21:10:02 +08:00
Wesley Liddick
a5547b2f30 Merge pull request #380 from DDZS987/fix/oauth-token-refresh-missing-project-id-retry
fix(oauth): 修复 OAuth 令牌刷新时 missing_project_id 误报问题
2026-01-24 20:29:43 +08:00
shaw
b0aa23540b feat(subscription): 订阅过期状态自动更新与服务端排序
- 新增 SubscriptionExpiryService 定时任务,每分钟更新过期订阅状态
- 订阅列表支持服务端排序(按过期时间、状态、创建时间)
- 实时显示正确的过期状态,无需等待定时任务
- 允许对已过期订阅进行续期操作
- DataTable 组件支持 serverSideSort 模式
2026-01-24 20:26:01 +08:00
Ubuntu
ffaa6c4a17 fix(oauth): 修复 OAuth 令牌刷新时 missing_project_id 误报问题
问题描述:
在网络波动环境下,LoadCodeAssist 临时失败会被错误地标记为
"missing_project_id" 不可重试错误,导致账户被禁用。但实际上
账户配置正确,手动刷新后即可恢复。

解决方案:
1. 为 LoadCodeAssist 增加重试机制(最多4次,指数退避)
2. 区分真正的配置缺失和临时网络故障:
   - 如果之前有 project_id,本次获取失败只保留旧值,不报错
   - 只有从未获取过 project_id 且本次也失败时,才标记为缺失
3. 优化错误判断逻辑,避免误报

改进效果:
- 提高在复杂网络环境(如家宽代理)下的鲁棒性
- 减少因临时网络波动导致的服务中断
- 保持真正配置错误的检测能力

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-24 17:44:56 +08:00
Wesley Liddick
fbf72f0ec4 Merge pull request #377 from lynoot/fix/non-streaming-chunk-aggregation
fix(gateway): aggregate all text chunks in non-streaming Gemini responses
2026-01-24 08:49:10 +08:00
lynoot
909b8a8f9c fix(gateway): aggregate all text chunks in non-streaming Gemini responses
Previously, collectGeminiSSE() only returned the last chunk received
from the upstream streaming response when converting to non-streaming.
This caused incomplete responses where only the final text fragment
was returned to clients.

For example, a request asking to "count from 1 to 10" would only
return "\n" (the last chunk) instead of "1\n2\n3\n...\n10\n".

This was especially problematic for JSON structured output where
the opening brace "{" from the first chunk was lost, resulting
in invalid JSON like: colors": ["red", "blue"]}

The fix:
- Collect all text parts from each SSE chunk into a slice
- Merge all collected text parts into the final response
- Reuse the same pattern as handleGeminiStreamToNonStreaming
  in antigravity_gateway_service.go

Fixes: non-streaming responses returning incomplete text
Fixes: structured output (JSON schema) returning invalid JSON
2026-01-23 13:54:09 +00:00
shaw
4a0fe3b143 feat(gateway): 增加 SUGGESTION MODE 请求拦截
扩展现有的预热请求拦截功能,新增对 SUGGESTION MODE 请求的拦截:
- 检测 messages 最后一条 user 消息是否以 [SUGGESTION MODE: 开头
- 拦截后返回空内容响应,节省 token 消耗
- 重构检测逻辑,合并为单一函数,只解析一次 JSON
2026-01-23 16:57:25 +08:00
shaw
a1292fac81 feat(oauth): 支持Anthropic的Team账号使用sk授权 2026-01-23 16:30:12 +08:00
shaw
7f98be4f91 fix(oauth): 更新 Anthropic 账号 OAuth 参数,同步最新客户端 2026-01-23 16:00:42 +08:00
shaw
fd73b8875d feat(frontend): 优化账号限流状态显示,直接展示倒计时 2026-01-23 15:48:25 +08:00
shaw
f9ab1daa3c feat: 保存并显示OAuth账号邮箱地址 2026-01-23 15:17:47 +08:00
shaw
d27b847442 fix(test): 修复测试中引用不存在的函数名
测试文件引用了 IsTokenVersionStale 函数,但实际函数名为 CheckTokenVersion,导致 CI 构建失败
2026-01-23 10:58:30 +08:00
shaw
dac6bc2228 fix(token-cache): 版本过时时使用最新token而非旧token
上次修复(2665230)只阻止了写入缓存,但仍返回旧token导致403。
现在版本过时时直接使用DB中的最新token返回。

- 重构 IsTokenVersionStale 为 CheckTokenVersion,返回最新account
- 消除重复DB查询,复用版本检查时已获取的account
2026-01-23 10:29:52 +08:00
Wesley Liddick
4bd3dbf2ce Merge pull request #354 from DuckyProject/fix/frontend-table
feat(frontend): 账号表格默认排序/持久化 + 自动刷新 + 更多菜单外部关闭
2026-01-22 21:17:48 +08:00
Wesley Liddick
226df1c23a Merge pull request #358 from 0xff26b9a8/main
fix(antigravity): 修复非流式 Claude To Antigravity 响应内容为空的问题
2026-01-22 21:16:54 +08:00
shaw
2665230a09 fix(token-cache): 修复异步刷新与请求线程的缓存竞态条件
- 新增 _token_version 版本号机制,防止过期 token 污染缓存
- TokenRefreshService 刷新成功后写入版本号并清除缓存
- TokenProvider 写入缓存前检查版本,过时则跳过
- ClearError 时同步清除 token 缓存
2026-01-22 21:09:28 +08:00
0xff26b9a8
4f0c2b794c style: gofmt antigravity_gateway_service.go 2026-01-22 14:38:55 +08:00
0xff26b9a8
e756064c19 fix(antigravity): 修复非流式 Claude To Antigravity 响应内容为空的问题
- 修复 TransformGeminiToClaude 的 JSON 解析逻辑,当 V1InternalResponse
  解析成功但 candidates 为空时,尝试直接解析为 GeminiResponse 格式
- 修复 handleClaudeStreamToNonStreaming 收集流式响应的逻辑,累积所有
  chunks 的内容而不是只保留最后一个(最后一个 chunk 通常 text 为空)
- 新增 mergeCollectedPartsToResponse 函数,合并所有类型的 parts
  (text、thinking、functionCall、inlineData),保持原始顺序
- 连续的普通 text parts 合并为一个,thinking/functionCall/inlineData 保持原样
2026-01-22 14:17:59 +08:00
Wesley Liddick
17dfb0af01 Merge pull request #346 from 0xff26b9a8/main
refactor(antigravity): 提取并同步 Schema 清理逻辑至 schema_cleaner.go
2026-01-22 08:46:11 +08:00
ducky
ff74f517df feat(frontend): 账号表格默认排序/持久化 + 自动刷新 + 更多菜单外部关闭 2026-01-21 22:43:25 +08:00
0xff26b9a8
477a9a180f fix: 修复 schema 清理逻辑 2026-01-21 10:58:39 +08:00
0xff26b9a8
da48df06d2 refactor(antigravity): 提取并同步 Schema 清理逻辑至 schema_cleaner.go
主要变更:
1. 重构代码结构:
   - 将 CleanJSONSchema 及其相关辅助函数从 request_transformer.go 提取到独立的 schema_cleaner.go 文件中,实现逻辑解耦。

2. 逻辑优化与修正:
   - 参考 Antigravity-Manager (json_schema.rs) 的实现逻辑,修正了 Schema 清洗策略。
2026-01-20 23:41:53 +08:00
Wesley Liddick
39fad63ccf Merge pull request #345 from whoismonay/main
mod(frontend): 管理员订阅/兑换码分组选择展示备注
2026-01-20 16:22:53 +08:00
Wesley Liddick
5602d02b1b Merge pull request #343 from mt21625457/main
fix(调度): 完善粘性会话清理与账号调度刷新 和 启用 OpenAI OAuth HTTP/2 并修复清理任务 lint
2026-01-20 16:05:53 +08:00
shaw
81989eed1c test: add promo_code_enabled to API contract test 2026-01-20 16:02:49 +08:00
shaw
192efb84a0 feat(promo-code): complete promo code feature implementation
- Add promo_code_enabled field to SystemSettings and PublicSettings DTOs
- Add promo code validation in registration flow
- Add admin settings UI for promo code configuration
- Add i18n translations for promo code feature
2026-01-20 15:56:26 +08:00
shaw
8672347f93 fix(settings): add missing promo_code_enabled field in public settings API
The field was defined in DTO but not mapped in handler response.
2026-01-20 15:49:57 +08:00
yangjianbo
5e5d4a513b feat: 移动镜像脚本位置 2026-01-20 15:11:27 +08:00
墨颜
88b6358472 build(frontend): vite 加载开发环境变量
- 使用 loadEnv 读取 VITE_DEV_PROXY_TARGET/VITE_DEV_PORT
- 注入 public settings 与 dev proxy 使用同源后端地址
2026-01-20 15:04:18 +08:00
墨颜
dd8d5e2c42 mod(frontend): 订阅分组下拉显示备注
- 订阅管理/分配订阅:下拉项展示分组备注
- 兑换码/订阅类型:下拉项展示分组备注
- 复用 GroupOptionItem/GroupBadge 保持一致体验
2026-01-20 15:02:48 +08:00
shaw
d91e2328fb test(tlsfingerprint): add multi-profile fingerprint verification test
Add TestAllProfiles to verify TLS fingerprint configurations from
config.yaml against tls.peet.ws. Tests check JA4 cipher hash (stable
part) to validate fingerprint spoofing works correctly.
2026-01-20 14:53:15 +08:00
yangjianbo
2a16735495 fix(测试): 修复 SelectAccountWithLoadAwareness 调用缺少参数
为 gateway_multiplatform_test.go 中的 SelectAccountWithLoadAwareness
调用添加缺少的第6个参数 metadataUserID,修复 CI 测试编译错误。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 14:16:46 +08:00
yangjianbo
292f25f9ca Merge branch 'main' of https://github.com/mt21625457/aicodex2api 2026-01-20 14:02:08 +08:00
yangjianbo
c92e37775a Merge branch 'dev' 2026-01-20 13:57:08 +08:00
yangjianbo
f6ed3d1456 Merge branch 'test' into dev 2026-01-20 11:59:13 +08:00
yangjianbo
84686753e8 Merge branch 'test' of https://github.com/mt21625457/aicodex2api into test 2026-01-20 11:51:44 +08:00
yangjianbo
91f01309da fix(调度): 完善粘性会话清理与账号调度刷新
- Update/BulkUpdate 按不可调度字段触发缓存刷新
- GatewayCache 支持多前缀会话键清理
- 模型路由与混合调度优化粘性会话处理
- 补充调度与缓存相关测试覆盖
2026-01-20 11:40:55 +08:00
yangjianbo
57a1fc9d33 style(仓储): 格式化账号仓储
- gofmt 修正 lint 格式提示
2026-01-20 11:30:36 +08:00
shaw
c95a864975 docs: add TLS fingerprint tool link 2026-01-20 11:30:10 +08:00
yangjianbo
7a83db6180 fix(调度): 完善粘性会话清理与账号调度刷新
- Update/BulkUpdate 按不可调度字段触发缓存刷新
- GatewayCache 支持多前缀会话键清理
- 模型路由与混合调度优化粘性会话处理
- 补充调度与缓存相关测试覆盖
2026-01-20 11:19:32 +08:00
Wesley Liddick
a8513da7ff Merge pull request #335 from geminiwen/main
feat(subscription): 支持调整订阅时长(延长/缩短)
2026-01-20 08:52:19 +08:00
Gemini Wen
53534d3956 style(admin): 统一列设置按钮位置到刷新按钮右侧
将订阅管理和账号管理页面的列设置按钮移动到刷新按钮右侧,
与用户管理页面保持一致的布局。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:18:51 +08:00
Gemini Wen
cc07a0e295 feat(subscription): 支持调整订阅时长(延长/缩短)
- 将"延长订阅"功能改为"调整订阅",支持正数延长、负数缩短
- 后端验证:调整天数范围 -36500 到 36500,缩短后剩余天数必须 > 0
- 前端同步更新界面文案和验证逻辑

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 00:11:30 +08:00
Wesley Liddick
e7bc62500b Merge pull request #333 from whoismonay/main
fix: 普通用户接口移除管理员敏感字段透传
2026-01-19 21:35:51 +08:00
墨颜
c8fb9ef3a5 style(dto): 修复 gofmt 格式问题
- 修复 mappers.go 中 Notes 字段的对齐格式
- 修复 types.go 中 BulkAssignResult 结构体字段的 JSON tag 对齐

修复 golangci-lint 检查中的 gofmt 格式错误
2026-01-19 21:26:30 +08:00
Wesley Liddick
eb5e6214bc Merge pull request #332 from geminiwen/main
fix: 修复手动刷新令牌后缓存未清除导致403错误的问题
2026-01-19 20:53:06 +08:00
shaw
568d6ee10e fix: 修复测试缺少新增设置字段 2026-01-19 20:52:05 +08:00
墨颜
6aef1af76e fix(redeem): 用户兑换历史不返回备注
- 用户侧 RedeemCode DTO 移除 notes 字段,避免泄露内部备注\n- 新增 AdminRedeemCode,并调整管理员兑换码接口继续返回 notes\n- 增加 /api/v1/redeem/history 契约测试,确保用户侧响应不包含 notes
2026-01-19 20:09:35 +08:00
Gemini Wen
a54852e129 fix: 补充API契约测试中缺失的hide_ccs_import_button字段
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 20:06:36 +08:00
Gemini Wen
668118def1 fix: 修复遗漏的测试文件更新和lint错误
- 更新api_contract_test.go以匹配NewAccountHandler新增的tokenCacheInvalidator参数
- 修复errcheck lint错误,显式忽略c.Error()返回值

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 19:58:09 +08:00
yangjianbo
73e6b160f8 feat(认证): 启用 OpenAI OAuth HTTP/2 并修复清理任务 lint
为共享 req 客户端增加 HTTP/2 选项与缓存隔离
OpenAI OAuth 超时提升到 120s,并按协议控制强制
新增客户端池与 OAuth 客户端单测覆盖
修复 usage cleanup 相关 errcheck/ineffassign/staticcheck 并统一格式

测试: make test
2026-01-19 19:50:57 +08:00
Gemini Wen
6fec141de6 fix: 修复手动刷新令牌后缓存未清除导致403错误的问题
手动刷新令牌后,新token保存到数据库但Redis缓存未清除,
导致下游请求仍然使用旧的失效token,上游API返回403错误。

修复方案:在AccountHandler中注入TokenCacheInvalidator,
刷新令牌成功后调用InvalidateToken清除缓存。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 19:40:43 +08:00
墨颜
31cde6c555 fix(subscriptions): 用户订阅不返回分配信息
- 用户侧 UserSubscription DTO 移除 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段\n- 新增 AdminUserSubscription,并调整管理员订阅接口与批量分配结果使用\n- 增加 /api/v1/subscriptions 契约测试,确保用户侧响应不包含上述字段
2026-01-19 19:35:13 +08:00
shaw
b1a980f344 feat: 添加隐藏CCS导入按钮的设置选项
在管理后台设置页面新增开关,允许管理员隐藏API Keys页面的"导入CCS"按钮
2026-01-19 19:25:16 +08:00
墨颜
00d9fbd220 fix(user): 普通用户接口不返回备注
- 用户侧 dto.User 移除 notes 字段,避免泄露管理员备注\n- 新增 dto.AdminUser 并调整 /admin/users 系列接口使用\n- 前端拆分 User/AdminUser,管理端用户页面使用 AdminUser\n- 更新契约测试:/api/v1/auth/me 响应不包含 notes
2026-01-19 19:23:51 +08:00
墨颜
4f4c9679bf fix(groups): 用户分组不下发内部路由信息
- 普通用户 Group DTO 移除 model_routing/account_count/account_groups,避免泄露内部路由与账号信息\n- 新增 AdminGroup DTO,并仅在管理员分组接口返回完整字段\n- 前端拆分 Group/AdminGroup,管理端页面与 API 使用 AdminGroup\n- 增加 /api/v1/groups/available 契约测试,防止回归
2026-01-19 18:58:42 +08:00
shaw
3dab71729d feat: usage接口支持TLS指纹和缓存User-Agent 2026-01-19 17:06:16 +08:00
墨颜
2f6f758670 fix(usage): 用户使用记录不下发账号计费倍率
- 将 usage log DTO 拆分为用户/管理员两类
- 用户接口不返回 account_rate_multiplier/ip_address/account
- 管理员接口保留管理员字段
- 补充契约测试防止回归
2026-01-19 17:05:42 +08:00
shaw
090c8981dd fix: 更新Claude OAuth授权配置以匹配最新规范
- 更新TokenURL和RedirectURI为platform.claude.com
- 更新scope定义,区分浏览器URL和内部API调用
- 修正state/code_verifier生成算法使用base64url编码
- 修正授权URL参数顺序并添加code=true
- 更新token交换请求头匹配官方实现
- 清理未使用的类型和函数
2026-01-19 16:40:06 +08:00
shaw
fbb572948d fix: 修复会话数量查询使用错误的超时配置 2026-01-19 11:45:04 +08:00
175 changed files with 13886 additions and 1744 deletions

View File

@@ -222,8 +222,9 @@ jobs:
REPO="${{ github.repository }}"
GHCR_IMAGE="ghcr.io/${REPO,,}" # ${,,} converts to lowercase
# 获取 tag message 内容
# 获取 tag message 内容并转义 Markdown 特殊字符
TAG_MESSAGE='${{ steps.tag_message.outputs.message }}'
TAG_MESSAGE=$(echo "$TAG_MESSAGE" | sed 's/\([_*`\[]\)/\\\1/g')
# 限制消息长度Telegram 消息限制 4096 字符,预留空间给头尾固定内容)
if [ ${#TAG_MESSAGE} -gt 3500 ]; then

View File

@@ -18,7 +18,7 @@ English | [中文](README_CN.md)
## Demo
Try Sub2API online: **https://v2.pincc.ai/**
Try Sub2API online: **https://demo.sub2api.org/**
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):

View File

@@ -1 +1 @@
0.1.46
0.1.61

View File

@@ -70,6 +70,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
@@ -138,6 +139,10 @@ func provideCleanup(
accountExpiry.Stop()
return nil
}},
{"SubscriptionExpiryService", func() error {
subscriptionExpiry.Stop()
return nil
}},
{"PricingService", func() error {
pricing.Stop()
return nil

View File

@@ -63,7 +63,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil {
return nil, err
}
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService)
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
@@ -84,7 +90,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
}
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
accountRepository := repository.NewAccountRepository(client, db)
schedulerCache := repository.NewSchedulerCache(redisClient)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
@@ -105,21 +112,22 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
httpUpstream := repository.NewHTTPUpstream(configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
identityCache := repository.NewIdentityCache(redisClient)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -128,7 +136,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
schedulerCache := repository.NewSchedulerCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
@@ -137,7 +144,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityCache := repository.NewIdentityCache(redisClient)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
@@ -165,7 +171,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
totpHandler := handler.NewTotpHandler(totpService)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -178,7 +185,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -211,6 +219,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
@@ -278,6 +287,10 @@ func provideCleanup(
accountExpiry.Stop()
return nil
}},
{"SubscriptionExpiryService", func() error {
subscriptionExpiry.Stop()
return nil
}},
{"PricingService", func() error {
pricing.Stop()
return nil

View File

@@ -610,6 +610,9 @@ var (
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{

View File

@@ -14360,6 +14360,9 @@ type UserMutation struct {
status *string
username *string
notes *string
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -14937,6 +14940,140 @@ func (m *UserMutation) ResetNotes() {
m.notes = nil
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (m *UserMutation) SetTotpSecretEncrypted(s string) {
m.totp_secret_encrypted = &s
}
// TotpSecretEncrypted returns the value of the "totp_secret_encrypted" field in the mutation.
func (m *UserMutation) TotpSecretEncrypted() (r string, exists bool) {
v := m.totp_secret_encrypted
if v == nil {
return
}
return *v, true
}
// OldTotpSecretEncrypted returns the old "totp_secret_encrypted" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldTotpSecretEncrypted(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldTotpSecretEncrypted is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldTotpSecretEncrypted requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldTotpSecretEncrypted: %w", err)
}
return oldValue.TotpSecretEncrypted, nil
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (m *UserMutation) ClearTotpSecretEncrypted() {
m.totp_secret_encrypted = nil
m.clearedFields[user.FieldTotpSecretEncrypted] = struct{}{}
}
// TotpSecretEncryptedCleared returns if the "totp_secret_encrypted" field was cleared in this mutation.
func (m *UserMutation) TotpSecretEncryptedCleared() bool {
_, ok := m.clearedFields[user.FieldTotpSecretEncrypted]
return ok
}
// ResetTotpSecretEncrypted resets all changes to the "totp_secret_encrypted" field.
func (m *UserMutation) ResetTotpSecretEncrypted() {
m.totp_secret_encrypted = nil
delete(m.clearedFields, user.FieldTotpSecretEncrypted)
}
// SetTotpEnabled sets the "totp_enabled" field.
func (m *UserMutation) SetTotpEnabled(b bool) {
m.totp_enabled = &b
}
// TotpEnabled returns the value of the "totp_enabled" field in the mutation.
func (m *UserMutation) TotpEnabled() (r bool, exists bool) {
v := m.totp_enabled
if v == nil {
return
}
return *v, true
}
// OldTotpEnabled returns the old "totp_enabled" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldTotpEnabled(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldTotpEnabled is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldTotpEnabled requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldTotpEnabled: %w", err)
}
return oldValue.TotpEnabled, nil
}
// ResetTotpEnabled resets all changes to the "totp_enabled" field.
func (m *UserMutation) ResetTotpEnabled() {
m.totp_enabled = nil
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (m *UserMutation) SetTotpEnabledAt(t time.Time) {
m.totp_enabled_at = &t
}
// TotpEnabledAt returns the value of the "totp_enabled_at" field in the mutation.
func (m *UserMutation) TotpEnabledAt() (r time.Time, exists bool) {
v := m.totp_enabled_at
if v == nil {
return
}
return *v, true
}
// OldTotpEnabledAt returns the old "totp_enabled_at" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldTotpEnabledAt(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldTotpEnabledAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldTotpEnabledAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldTotpEnabledAt: %w", err)
}
return oldValue.TotpEnabledAt, nil
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (m *UserMutation) ClearTotpEnabledAt() {
m.totp_enabled_at = nil
m.clearedFields[user.FieldTotpEnabledAt] = struct{}{}
}
// TotpEnabledAtCleared returns if the "totp_enabled_at" field was cleared in this mutation.
func (m *UserMutation) TotpEnabledAtCleared() bool {
_, ok := m.clearedFields[user.FieldTotpEnabledAt]
return ok
}
// ResetTotpEnabledAt resets all changes to the "totp_enabled_at" field.
func (m *UserMutation) ResetTotpEnabledAt() {
m.totp_enabled_at = nil
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -15403,7 +15540,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
fields := make([]string, 0, 11)
fields := make([]string, 0, 14)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -15437,6 +15574,15 @@ func (m *UserMutation) Fields() []string {
if m.notes != nil {
fields = append(fields, user.FieldNotes)
}
if m.totp_secret_encrypted != nil {
fields = append(fields, user.FieldTotpSecretEncrypted)
}
if m.totp_enabled != nil {
fields = append(fields, user.FieldTotpEnabled)
}
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
return fields
}
@@ -15467,6 +15613,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.Username()
case user.FieldNotes:
return m.Notes()
case user.FieldTotpSecretEncrypted:
return m.TotpSecretEncrypted()
case user.FieldTotpEnabled:
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
}
return nil, false
}
@@ -15498,6 +15650,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldUsername(ctx)
case user.FieldNotes:
return m.OldNotes(ctx)
case user.FieldTotpSecretEncrypted:
return m.OldTotpSecretEncrypted(ctx)
case user.FieldTotpEnabled:
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -15584,6 +15742,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetNotes(v)
return nil
case user.FieldTotpSecretEncrypted:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetTotpSecretEncrypted(v)
return nil
case user.FieldTotpEnabled:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetTotpEnabled(v)
return nil
case user.FieldTotpEnabledAt:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetTotpEnabledAt(v)
return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -15644,6 +15823,12 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldDeletedAt) {
fields = append(fields, user.FieldDeletedAt)
}
if m.FieldCleared(user.FieldTotpSecretEncrypted) {
fields = append(fields, user.FieldTotpSecretEncrypted)
}
if m.FieldCleared(user.FieldTotpEnabledAt) {
fields = append(fields, user.FieldTotpEnabledAt)
}
return fields
}
@@ -15661,6 +15846,12 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldDeletedAt:
m.ClearDeletedAt()
return nil
case user.FieldTotpSecretEncrypted:
m.ClearTotpSecretEncrypted()
return nil
case user.FieldTotpEnabledAt:
m.ClearTotpEnabledAt()
return nil
}
return fmt.Errorf("unknown User nullable field %s", name)
}
@@ -15702,6 +15893,15 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldNotes:
m.ResetNotes()
return nil
case user.FieldTotpSecretEncrypted:
m.ResetTotpSecretEncrypted()
return nil
case user.FieldTotpEnabled:
m.ResetTotpEnabled()
return nil
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
}
return fmt.Errorf("unknown User field %s", name)
}

View File

@@ -736,6 +736,10 @@ func init() {
userDescNotes := userFields[7].Descriptor()
// user.DefaultNotes holds the default value on creation for the notes field.
user.DefaultNotes = userDescNotes.Default.(string)
// userDescTotpEnabled is the schema descriptor for totp_enabled field.
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.

View File

@@ -61,6 +61,17 @@ func (User) Fields() []ent.Field {
field.String("notes").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Default(""),
// TOTP 双因素认证字段
field.String("totp_secret_encrypted").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Optional().
Nillable(),
field.Bool("totp_enabled").
Default(false),
field.Time("totp_enabled_at").
Optional().
Nillable(),
}
}

View File

@@ -39,6 +39,12 @@ type User struct {
Username string `json:"username,omitempty"`
// Notes holds the value of the "notes" field.
Notes string `json:"notes,omitempty"`
// TotpSecretEncrypted holds the value of the "totp_secret_encrypted" field.
TotpSecretEncrypted *string `json:"totp_secret_encrypted,omitempty"`
// TotpEnabled holds the value of the "totp_enabled" field.
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -156,13 +162,15 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case user.FieldTotpEnabled:
values[i] = new(sql.NullBool)
case user.FieldBalance:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes:
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -252,6 +260,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Notes = value.String
}
case user.FieldTotpSecretEncrypted:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field totp_secret_encrypted", values[i])
} else if value.Valid {
_m.TotpSecretEncrypted = new(string)
*_m.TotpSecretEncrypted = value.String
}
case user.FieldTotpEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field totp_enabled", values[i])
} else if value.Valid {
_m.TotpEnabled = value.Bool
}
case user.FieldTotpEnabledAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field totp_enabled_at", values[i])
} else if value.Valid {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -367,6 +395,19 @@ func (_m *User) String() string {
builder.WriteString(", ")
builder.WriteString("notes=")
builder.WriteString(_m.Notes)
builder.WriteString(", ")
if v := _m.TotpSecretEncrypted; v != nil {
builder.WriteString("totp_secret_encrypted=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("totp_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.TotpEnabled))
builder.WriteString(", ")
if v := _m.TotpEnabledAt; v != nil {
builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteByte(')')
return builder.String()
}

View File

@@ -37,6 +37,12 @@ const (
FieldUsername = "username"
// FieldNotes holds the string denoting the notes field in the database.
FieldNotes = "notes"
// FieldTotpSecretEncrypted holds the string denoting the totp_secret_encrypted field in the database.
FieldTotpSecretEncrypted = "totp_secret_encrypted"
// FieldTotpEnabled holds the string denoting the totp_enabled field in the database.
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -134,6 +140,9 @@ var Columns = []string{
FieldStatus,
FieldUsername,
FieldNotes,
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
}
var (
@@ -188,6 +197,8 @@ var (
UsernameValidator func(string) error
// DefaultNotes holds the default value on creation for the "notes" field.
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
)
// OrderOption defines the ordering options for the User queries.
@@ -253,6 +264,21 @@ func ByNotes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldNotes, opts...).ToFunc()
}
// ByTotpSecretEncrypted orders the results by the totp_secret_encrypted field.
func ByTotpSecretEncrypted(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpSecretEncrypted, opts...).ToFunc()
}
// ByTotpEnabled orders the results by the totp_enabled field.
func ByTotpEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabled, opts...).ToFunc()
}
// ByTotpEnabledAt orders the results by the totp_enabled_at field.
func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -110,6 +110,21 @@ func Notes(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldNotes, v))
}
// TotpSecretEncrypted applies equality check predicate on the "totp_secret_encrypted" field. It's identical to TotpSecretEncryptedEQ.
func TotpSecretEncrypted(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
}
// TotpEnabled applies equality check predicate on the "totp_enabled" field. It's identical to TotpEnabledEQ.
func TotpEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
}
// TotpEnabledAt applies equality check predicate on the "totp_enabled_at" field. It's identical to TotpEnabledAtEQ.
func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -710,6 +725,141 @@ func NotesContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldNotes, v))
}
// TotpSecretEncryptedEQ applies the EQ predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedNEQ applies the NEQ predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedIn applies the In predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldTotpSecretEncrypted, vs...))
}
// TotpSecretEncryptedNotIn applies the NotIn predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldTotpSecretEncrypted, vs...))
}
// TotpSecretEncryptedGT applies the GT predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedGTE applies the GTE predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedLT applies the LT predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedLTE applies the LTE predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedContains applies the Contains predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedHasPrefix applies the HasPrefix predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedHasSuffix applies the HasSuffix predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedIsNil applies the IsNil predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldTotpSecretEncrypted))
}
// TotpSecretEncryptedNotNil applies the NotNil predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpSecretEncrypted))
}
// TotpSecretEncryptedEqualFold applies the EqualFold predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldTotpSecretEncrypted, v))
}
// TotpSecretEncryptedContainsFold applies the ContainsFold predicate on the "totp_secret_encrypted" field.
func TotpSecretEncryptedContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldTotpSecretEncrypted, v))
}
// TotpEnabledEQ applies the EQ predicate on the "totp_enabled" field.
func TotpEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
}
// TotpEnabledNEQ applies the NEQ predicate on the "totp_enabled" field.
func TotpEnabledNEQ(v bool) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotpEnabled, v))
}
// TotpEnabledAtEQ applies the EQ predicate on the "totp_enabled_at" field.
func TotpEnabledAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// TotpEnabledAtNEQ applies the NEQ predicate on the "totp_enabled_at" field.
func TotpEnabledAtNEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotpEnabledAt, v))
}
// TotpEnabledAtIn applies the In predicate on the "totp_enabled_at" field.
func TotpEnabledAtIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldIn(FieldTotpEnabledAt, vs...))
}
// TotpEnabledAtNotIn applies the NotIn predicate on the "totp_enabled_at" field.
func TotpEnabledAtNotIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldNotIn(FieldTotpEnabledAt, vs...))
}
// TotpEnabledAtGT applies the GT predicate on the "totp_enabled_at" field.
func TotpEnabledAtGT(v time.Time) predicate.User {
return predicate.User(sql.FieldGT(FieldTotpEnabledAt, v))
}
// TotpEnabledAtGTE applies the GTE predicate on the "totp_enabled_at" field.
func TotpEnabledAtGTE(v time.Time) predicate.User {
return predicate.User(sql.FieldGTE(FieldTotpEnabledAt, v))
}
// TotpEnabledAtLT applies the LT predicate on the "totp_enabled_at" field.
func TotpEnabledAtLT(v time.Time) predicate.User {
return predicate.User(sql.FieldLT(FieldTotpEnabledAt, v))
}
// TotpEnabledAtLTE applies the LTE predicate on the "totp_enabled_at" field.
func TotpEnabledAtLTE(v time.Time) predicate.User {
return predicate.User(sql.FieldLTE(FieldTotpEnabledAt, v))
}
// TotpEnabledAtIsNil applies the IsNil predicate on the "totp_enabled_at" field.
func TotpEnabledAtIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldTotpEnabledAt))
}
// TotpEnabledAtNotNil applies the NotNil predicate on the "totp_enabled_at" field.
func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@@ -167,6 +167,48 @@ func (_c *UserCreate) SetNillableNotes(v *string) *UserCreate {
return _c
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (_c *UserCreate) SetTotpSecretEncrypted(v string) *UserCreate {
_c.mutation.SetTotpSecretEncrypted(v)
return _c
}
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotpSecretEncrypted(v *string) *UserCreate {
if v != nil {
_c.SetTotpSecretEncrypted(*v)
}
return _c
}
// SetTotpEnabled sets the "totp_enabled" field.
func (_c *UserCreate) SetTotpEnabled(v bool) *UserCreate {
_c.mutation.SetTotpEnabled(v)
return _c
}
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotpEnabled(v *bool) *UserCreate {
if v != nil {
_c.SetTotpEnabled(*v)
}
return _c
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (_c *UserCreate) SetTotpEnabledAt(v time.Time) *UserCreate {
_c.mutation.SetTotpEnabledAt(v)
return _c
}
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
if v != nil {
_c.SetTotpEnabledAt(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -362,6 +404,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultNotes
_c.mutation.SetNotes(v)
}
if _, ok := _c.mutation.TotpEnabled(); !ok {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
return nil
}
@@ -422,6 +468,9 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.Notes(); !ok {
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
}
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
return nil
}
@@ -493,6 +542,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldNotes, field.TypeString, value)
_node.Notes = value
}
if value, ok := _c.mutation.TotpSecretEncrypted(); ok {
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
_node.TotpSecretEncrypted = &value
}
if value, ok := _c.mutation.TotpEnabled(); ok {
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
_node.TotpEnabled = value
}
if value, ok := _c.mutation.TotpEnabledAt(); ok {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -815,6 +876,54 @@ func (u *UserUpsert) UpdateNotes() *UserUpsert {
return u
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (u *UserUpsert) SetTotpSecretEncrypted(v string) *UserUpsert {
u.Set(user.FieldTotpSecretEncrypted, v)
return u
}
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotpSecretEncrypted() *UserUpsert {
u.SetExcluded(user.FieldTotpSecretEncrypted)
return u
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (u *UserUpsert) ClearTotpSecretEncrypted() *UserUpsert {
u.SetNull(user.FieldTotpSecretEncrypted)
return u
}
// SetTotpEnabled sets the "totp_enabled" field.
func (u *UserUpsert) SetTotpEnabled(v bool) *UserUpsert {
u.Set(user.FieldTotpEnabled, v)
return u
}
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotpEnabled() *UserUpsert {
u.SetExcluded(user.FieldTotpEnabled)
return u
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (u *UserUpsert) SetTotpEnabledAt(v time.Time) *UserUpsert {
u.Set(user.FieldTotpEnabledAt, v)
return u
}
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotpEnabledAt() *UserUpsert {
u.SetExcluded(user.FieldTotpEnabledAt)
return u
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
u.SetNull(user.FieldTotpEnabledAt)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1021,6 +1130,62 @@ func (u *UserUpsertOne) UpdateNotes() *UserUpsertOne {
})
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (u *UserUpsertOne) SetTotpSecretEncrypted(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotpSecretEncrypted(v)
})
}
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotpSecretEncrypted() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpSecretEncrypted()
})
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (u *UserUpsertOne) ClearTotpSecretEncrypted() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearTotpSecretEncrypted()
})
}
// SetTotpEnabled sets the "totp_enabled" field.
func (u *UserUpsertOne) SetTotpEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabled(v)
})
}
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotpEnabled() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabled()
})
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (u *UserUpsertOne) SetTotpEnabledAt(v time.Time) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabledAt(v)
})
}
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotpEnabledAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabledAt()
})
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearTotpEnabledAt()
})
}
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1393,6 +1558,62 @@ func (u *UserUpsertBulk) UpdateNotes() *UserUpsertBulk {
})
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (u *UserUpsertBulk) SetTotpSecretEncrypted(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotpSecretEncrypted(v)
})
}
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotpSecretEncrypted() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpSecretEncrypted()
})
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (u *UserUpsertBulk) ClearTotpSecretEncrypted() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearTotpSecretEncrypted()
})
}
// SetTotpEnabled sets the "totp_enabled" field.
func (u *UserUpsertBulk) SetTotpEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabled(v)
})
}
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotpEnabled() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabled()
})
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (u *UserUpsertBulk) SetTotpEnabledAt(v time.Time) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotpEnabledAt(v)
})
}
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotpEnabledAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotpEnabledAt()
})
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearTotpEnabledAt()
})
}
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -187,6 +187,60 @@ func (_u *UserUpdate) SetNillableNotes(v *string) *UserUpdate {
return _u
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (_u *UserUpdate) SetTotpSecretEncrypted(v string) *UserUpdate {
_u.mutation.SetTotpSecretEncrypted(v)
return _u
}
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotpSecretEncrypted(v *string) *UserUpdate {
if v != nil {
_u.SetTotpSecretEncrypted(*v)
}
return _u
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (_u *UserUpdate) ClearTotpSecretEncrypted() *UserUpdate {
_u.mutation.ClearTotpSecretEncrypted()
return _u
}
// SetTotpEnabled sets the "totp_enabled" field.
func (_u *UserUpdate) SetTotpEnabled(v bool) *UserUpdate {
_u.mutation.SetTotpEnabled(v)
return _u
}
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotpEnabled(v *bool) *UserUpdate {
if v != nil {
_u.SetTotpEnabled(*v)
}
return _u
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (_u *UserUpdate) SetTotpEnabledAt(v time.Time) *UserUpdate {
_u.mutation.SetTotpEnabledAt(v)
return _u
}
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotpEnabledAt(v *time.Time) *UserUpdate {
if v != nil {
_u.SetTotpEnabledAt(*v)
}
return _u
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
_u.mutation.ClearTotpEnabledAt()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -603,6 +657,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
}
if _u.mutation.TotpSecretEncryptedCleared() {
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
}
if value, ok := _u.mutation.TotpEnabled(); ok {
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.TotpEnabledAt(); ok {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
}
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1147,6 +1216,60 @@ func (_u *UserUpdateOne) SetNillableNotes(v *string) *UserUpdateOne {
return _u
}
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
func (_u *UserUpdateOne) SetTotpSecretEncrypted(v string) *UserUpdateOne {
_u.mutation.SetTotpSecretEncrypted(v)
return _u
}
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotpSecretEncrypted(v *string) *UserUpdateOne {
if v != nil {
_u.SetTotpSecretEncrypted(*v)
}
return _u
}
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
func (_u *UserUpdateOne) ClearTotpSecretEncrypted() *UserUpdateOne {
_u.mutation.ClearTotpSecretEncrypted()
return _u
}
// SetTotpEnabled sets the "totp_enabled" field.
func (_u *UserUpdateOne) SetTotpEnabled(v bool) *UserUpdateOne {
_u.mutation.SetTotpEnabled(v)
return _u
}
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotpEnabled(v *bool) *UserUpdateOne {
if v != nil {
_u.SetTotpEnabled(*v)
}
return _u
}
// SetTotpEnabledAt sets the "totp_enabled_at" field.
func (_u *UserUpdateOne) SetTotpEnabledAt(v time.Time) *UserUpdateOne {
_u.mutation.SetTotpEnabledAt(v)
return _u
}
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotpEnabledAt(v *time.Time) *UserUpdateOne {
if v != nil {
_u.SetTotpEnabledAt(*v)
}
return _u
}
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
_u.mutation.ClearTotpEnabledAt()
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1593,6 +1716,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
}
if _u.mutation.TotpSecretEncryptedCleared() {
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
}
if value, ok := _u.mutation.TotpEnabled(); ok {
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.TotpEnabledAt(); ok {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
}
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,

View File

@@ -37,6 +37,7 @@ require (
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -106,6 +107,7 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/pquerna/otp v1.5.0 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.57.1 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect

View File

@@ -20,6 +20,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -217,6 +219,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=

View File

@@ -47,6 +47,7 @@ type Config struct {
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
@@ -466,6 +467,16 @@ type JWTConfig struct {
ExpireHour int `mapstructure:"expire_hour"`
}
// TotpConfig TOTP 双因素认证配置
type TotpConfig struct {
// EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥32 字节 hex 编码)
// 如果为空,将自动生成一个随机密钥(仅适用于开发环境)
EncryptionKey string `mapstructure:"encryption_key"`
// EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成)
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
EncryptionKeyConfigured bool `mapstructure:"-"`
}
type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
@@ -626,6 +637,20 @@ func Load() (*Config, error) {
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
if cfg.Totp.EncryptionKey == "" {
key, err := generateJWTSecret(32) // Reuse the same random generation function
if err != nil {
return nil, fmt.Errorf("generate totp encryption key error: %w", err)
}
cfg.Totp.EncryptionKey = key
cfg.Totp.EncryptionKeyConfigured = false
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
} else {
cfg.Totp.EncryptionKeyConfigured = true
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
@@ -756,6 +781,9 @@ func setDefaults() {
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
// TOTP
viper.SetDefault("totp.encryption_key", "")
// Default
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.

View File

@@ -45,6 +45,7 @@ type AccountHandler struct {
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
sessionLimitCache service.SessionLimitCache
tokenCacheInvalidator service.TokenCacheInvalidator
}
// NewAccountHandler creates a new admin account handler
@@ -60,6 +61,7 @@ func NewAccountHandler(
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
sessionLimitCache service.SessionLimitCache,
tokenCacheInvalidator service.TokenCacheInvalidator,
) *AccountHandler {
return &AccountHandler{
adminService: adminService,
@@ -73,6 +75,7 @@ func NewAccountHandler(
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
sessionLimitCache: sessionLimitCache,
tokenCacheInvalidator: tokenCacheInvalidator,
}
}
@@ -173,6 +176,7 @@ func (h *AccountHandler) List(c *gin.Context) {
// 识别需要查询窗口费用和会话数的账号Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
if acc.IsAnthropicOAuthOrSetupToken() {
@@ -181,6 +185,7 @@ func (h *AccountHandler) List(c *gin.Context) {
}
if acc.GetMaxSessions() > 0 {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
}
}
@@ -189,9 +194,9 @@ func (h *AccountHandler) List(c *gin.Context) {
var windowCosts map[int64]float64
var activeSessions map[int64]int
// 获取活跃会话数(批量查询)
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
if activeSessions == nil {
activeSessions = make(map[int64]int)
}
@@ -542,9 +547,18 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败project_id 也不会丢失
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
newCredentials["project_id"] = oldProjectID
}
}
// 如果 project_id 获取失败,更新凭证但不标记为 error
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
if tokenInfo.ProjectIDMissing {
// 先更新凭证
// 先更新凭证token 本身刷新成功了)
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
@@ -552,14 +566,10 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
return
}
// 标记账户为 error
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id可能无法使用Antigravity"); setErr != nil {
response.InternalError(c, "Failed to set account error: "+setErr.Error())
return
}
// 标记为 error,只返回警告信息
response.Success(c, gin.H{
"message": "Token refreshed but project_id is missing, account marked as error",
"warning": "missing_project_id",
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
"warning": "missing_project_id_temporary",
})
return
}
@@ -606,6 +616,14 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
return
}
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
if h.tokenCacheInvalidator != nil {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
// 缓存失效失败只记录日志,不影响主流程
_ = c.Error(invalidateErr)
}
}
response.Success(c, dto.AccountFromService(updatedAccount))
}
@@ -655,6 +673,15 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
return
}
// 清除错误后,同时清除 token 缓存,确保下次请求会获取最新的 token触发刷新或从 DB 读取)
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
// 缓存失效失败只记录日志,不影响主流程
_ = c.Error(invalidateErr)
}
}
response.Success(c, dto.AccountFromService(account))
}

View File

@@ -94,9 +94,9 @@ func (h *GroupHandler) List(c *gin.Context) {
return
}
outGroups := make([]dto.Group, 0, len(groups))
outGroups := make([]dto.AdminGroup, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
}
response.Paginated(c, outGroups, total, page, pageSize)
}
@@ -120,9 +120,9 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
return
}
outGroups := make([]dto.Group, 0, len(groups))
outGroups := make([]dto.AdminGroup, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
}
response.Success(c, outGroups)
}
@@ -142,7 +142,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.GroupFromService(group))
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Create handles creating a new group
@@ -177,7 +177,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
return
}
response.Success(c, dto.GroupFromService(group))
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Update handles updating a group
@@ -219,7 +219,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.GroupFromService(group))
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Delete handles deleting a group

View File

@@ -54,9 +54,9 @@ func (h *RedeemHandler) List(c *gin.Context) {
return
}
out := make([]dto.RedeemCode, 0, len(codes))
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -76,7 +76,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.RedeemCodeFromService(code))
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
}
// Generate handles generating new redeem codes
@@ -100,9 +100,9 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return
}
out := make([]dto.RedeemCode, 0, len(codes))
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
response.Success(c, out)
}
@@ -163,7 +163,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
return
}
response.Success(c, dto.RedeemCodeFromService(code))
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
}
// GetStats handles getting redeem code statistics

View File

@@ -47,6 +47,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
@@ -68,6 +72,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
@@ -87,8 +94,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
@@ -111,13 +121,16 @@ type UpdateSettingsRequest struct {
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
@@ -194,6 +207,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// TOTP 双因素认证参数验证
// 只有手动配置了加密密钥才允许启用 TOTP 功能
if req.TotpEnabled && !previousSettings.TotpEnabled {
// 尝试启用 TOTP检查加密密钥是否已手动配置
if !h.settingService.IsTotpEncryptionKeyConfigured() {
response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
return
}
}
// LinuxDo Connect 参数验证
if req.LinuxDoConnectEnabled {
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
@@ -223,6 +246,34 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// “购买订阅”页面配置验证
purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled
if req.PurchaseSubscriptionEnabled != nil {
purchaseEnabled = *req.PurchaseSubscriptionEnabled
}
purchaseURL := previousSettings.PurchaseSubscriptionURL
if req.PurchaseSubscriptionURL != nil {
purchaseURL = strings.TrimSpace(*req.PurchaseSubscriptionURL)
}
// - 启用时要求 URL 合法且非空
// - 禁用时允许为空;若提供了 URL 也做基本校验,避免误配置
if purchaseEnabled {
if purchaseURL == "" {
response.BadRequest(c, "Purchase Subscription URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
return
}
} else if purchaseURL != "" {
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
return
}
}
// Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds
@@ -236,38 +287,44 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -311,6 +368,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
@@ -332,6 +393,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
HomeContent: updatedSettings.HomeContent,
HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
@@ -376,6 +440,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
if before.TotpEnabled != after.TotpEnabled {
changed = append(changed, "totp_enabled")
}
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}
@@ -439,6 +509,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.HomeContent != after.HomeContent {
changed = append(changed, "home_content")
}
if before.HideCcsImportButton != after.HideCcsImportButton {
changed = append(changed, "hide_ccs_import_button")
}
if before.DefaultConcurrency != after.DefaultConcurrency {
changed = append(changed, "default_concurrency")
}

View File

@@ -53,9 +53,9 @@ type BulkAssignSubscriptionRequest struct {
Notes string `json:"notes"`
}
// ExtendSubscriptionRequest represents extend subscription request
type ExtendSubscriptionRequest struct {
Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years
// AdjustSubscriptionRequest represents adjust subscription request (extend or shorten)
type AdjustSubscriptionRequest struct {
Days int `json:"days" binding:"required,min=-36500,max=36500"` // negative to shorten, positive to extend
}
// List handles listing all subscriptions with pagination and filters
@@ -77,15 +77,19 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
}
status := c.Query("status")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
// Parse sorting parameters
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.UserSubscription, 0, len(subscriptions))
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
@@ -105,7 +109,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.UserSubscriptionFromService(subscription))
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// GetProgress handles getting subscription usage progress
@@ -150,7 +154,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
return
}
response.Success(c, dto.UserSubscriptionFromService(subscription))
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// BulkAssign handles bulk assigning subscriptions to multiple users
@@ -180,7 +184,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
response.Success(c, dto.BulkAssignResultFromService(result))
}
// Extend handles extending a subscription
// Extend handles adjusting a subscription (extend or shorten)
// POST /api/v1/admin/subscriptions/:id/extend
func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
@@ -189,7 +193,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return
}
var req ExtendSubscriptionRequest
var req AdjustSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
@@ -201,7 +205,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return
}
response.Success(c, dto.UserSubscriptionFromService(subscription))
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// Revoke handles revoking a subscription
@@ -239,9 +243,9 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
return
}
out := make([]dto.UserSubscription, 0, len(subscriptions))
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
@@ -261,9 +265,9 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
return
}
out := make([]dto.UserSubscription, 0, len(subscriptions))
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.Success(c, out)
}

View File

@@ -163,7 +163,7 @@ func (h *UsageHandler) List(c *gin.Context) {
return
}
out := make([]dto.UsageLog, 0, len(records))
out := make([]dto.AdminUsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
}

View File

@@ -84,9 +84,9 @@ func (h *UserHandler) List(c *gin.Context) {
return
}
out := make([]dto.User, 0, len(users))
out := make([]dto.AdminUser, 0, len(users))
for i := range users {
out = append(out, *dto.UserFromService(&users[i]))
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -129,7 +129,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Create handles creating a new user
@@ -155,7 +155,7 @@ func (h *UserHandler) Create(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Update handles updating a user
@@ -189,7 +189,7 @@ func (h *UserHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Delete handles deleting a user
@@ -231,7 +231,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(user))
response.Success(c, dto.UserFromServiceAdmin(user))
}
// GetUserAPIKeys handles getting user's API keys

View File

@@ -1,6 +1,8 @@
package handler
import (
"log/slog"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
@@ -18,16 +20,18 @@ type AuthHandler struct {
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
totpService *service.TotpService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler {
return &AuthHandler{
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
totpService: totpService,
}
}
@@ -144,6 +148,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA
tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
if err != nil {
response.InternalError(c, "Failed to create 2FA session")
return
}
response.Success(c, TotpLoginResponse{
Requires2FA: true,
TempToken: tempToken,
UserEmailMasked: service.MaskEmail(user.Email),
})
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
}
// TotpLoginResponse represents the response when 2FA is required
type TotpLoginResponse struct {
Requires2FA bool `json:"requires_2fa"`
TempToken string `json:"temp_token,omitempty"`
UserEmailMasked string `json:"user_email_masked,omitempty"`
}
// Login2FARequest represents the 2FA login request
type Login2FARequest struct {
TempToken string `json:"temp_token" binding:"required"`
TotpCode string `json:"totp_code" binding:"required,len=6"`
}
// Login2FA completes the login with 2FA verification
// POST /api/v1/auth/login/2fa
func (h *AuthHandler) Login2FA(c *gin.Context) {
var req Login2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
slog.Debug("login_2fa_request",
"temp_token_len", len(req.TempToken),
"totp_code_len", len(req.TotpCode))
// Get the login session
session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
if err != nil || session == nil {
tokenPrefix := ""
if len(req.TempToken) >= 8 {
tokenPrefix = req.TempToken[:8]
}
slog.Debug("login_2fa_session_invalid",
"temp_token_prefix", tokenPrefix,
"error", err)
response.BadRequest(c, "Invalid or expired 2FA session")
return
}
slog.Debug("login_2fa_session_found",
"user_id", session.UserID,
"email", session.Email)
// Verify the TOTP code
if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
slog.Debug("login_2fa_verify_failed",
"user_id", session.UserID,
"error", err)
response.ErrorFrom(c, err)
return
}
// Delete the login session
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
// Get the user
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Generate the JWT token
token, err := h.authService.GenerateToken(user)
if err != nil {
response.InternalError(c, "Failed to generate token")
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
@@ -195,6 +293,15 @@ type ValidatePromoCodeResponse struct {
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
// POST /api/v1/auth/validate-promo-code
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
// 检查优惠码功能是否启用
if h.settingSvc != nil && !h.settingSvc.IsPromoCodeEnabled(c.Request.Context()) {
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: "PROMO_CODE_DISABLED",
})
return
}
var req ValidatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
@@ -238,3 +345,85 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
BonusAmount: promoCode.BonusAmount,
})
}
// ForgotPasswordRequest 忘记密码请求
type ForgotPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
TurnstileToken string `json:"turnstile_token"`
}
// ForgotPasswordResponse 忘记密码响应
type ForgotPasswordResponse struct {
Message string `json:"message"`
}
// ForgotPassword 请求密码重置
// POST /api/v1/auth/forgot-password
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
var req ForgotPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
// Build frontend base URL from request
scheme := "https"
if c.Request.TLS == nil {
// Check X-Forwarded-Proto header (common in reverse proxy setups)
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
scheme = proto
} else {
scheme = "http"
}
}
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ForgotPasswordResponse{
Message: "If your email is registered, you will receive a password reset link shortly.",
})
}
// ResetPasswordRequest 重置密码请求
type ResetPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
Token string `json:"token" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
// ResetPasswordResponse 重置密码响应
type ResetPasswordResponse struct {
Message string `json:"message"`
}
// ResetPassword 重置密码
// POST /api/v1/auth/reset-password
func (h *AuthHandler) ResetPassword(c *gin.Context) {
var req ResetPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Reset password
if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ResetPasswordResponse{
Message: "Your password has been reset successfully. You can now log in with your new password.",
})
}

View File

@@ -15,7 +15,6 @@ func UserFromServiceShallow(u *service.User) *User {
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
@@ -48,6 +47,22 @@ func UserFromService(u *service.User) *User {
return out
}
// UserFromServiceAdmin converts a service User to DTO for admin users.
// It includes notes - user-facing endpoints must not use this.
func UserFromServiceAdmin(u *service.User) *AdminUser {
if u == nil {
return nil
}
base := UserFromService(u)
if base == nil {
return nil
}
return &AdminUser{
User: *base,
Notes: u.Notes,
}
}
func APIKeyFromService(k *service.APIKey) *APIKey {
if k == nil {
return nil
@@ -72,36 +87,29 @@ func GroupFromServiceShallow(g *service.Group) *Group {
if g == nil {
return nil
}
return &Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
}
out := groupFromServiceBase(g)
return &out
}
func GroupFromService(g *service.Group) *Group {
if g == nil {
return nil
}
out := GroupFromServiceShallow(g)
return GroupFromServiceShallow(g)
}
// GroupFromServiceAdmin converts a service Group to DTO for admin users.
// It includes internal fields like model_routing and account_count.
func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
if g == nil {
return nil
}
out := &AdminGroup{
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
AccountCount: g.AccountCount,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
for i := range g.AccountGroups {
@@ -112,6 +120,29 @@ func GroupFromService(g *service.Group) *Group {
return out
}
func groupFromServiceBase(g *service.Group) Group {
return Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
}
func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
@@ -273,7 +304,24 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
if rc == nil {
return nil
}
return &RedeemCode{
out := redeemCodeFromServiceBase(rc)
return &out
}
// RedeemCodeFromServiceAdmin converts a service RedeemCode to DTO for admin users.
// It includes notes - user-facing endpoints must not use this.
func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
if rc == nil {
return nil
}
return &AdminRedeemCode{
RedeemCode: redeemCodeFromServiceBase(rc),
Notes: rc.Notes,
}
}
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
return RedeemCode{
ID: rc.ID,
Code: rc.Code,
Type: rc.Type,
@@ -281,7 +329,6 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
Status: rc.Status,
UsedBy: rc.UsedBy,
UsedAt: rc.UsedAt,
Notes: rc.Notes,
CreatedAt: rc.CreatedAt,
GroupID: rc.GroupID,
ValidityDays: rc.ValidityDays,
@@ -302,14 +349,9 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
}
}
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
// The account parameter allows caller to control what Account info is included.
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
if l == nil {
return nil
}
result := &UsageLog{
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
// 普通用户 DTO严禁包含管理员字段例如 account_rate_multiplier、ip_address、account
return UsageLog{
ID: l.ID,
UserID: l.UserID,
APIKeyID: l.APIKeyID,
@@ -331,7 +373,6 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
TotalCost: l.TotalCost,
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
AccountRateMultiplier: l.AccountRateMultiplier,
BillingType: l.BillingType,
Stream: l.Stream,
DurationMs: l.DurationMs,
@@ -342,30 +383,33 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey),
Account: account,
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),
}
// IP 地址仅对管理员可见
if includeIPAddress {
result.IPAddress = l.IPAddress
}
return result
}
// UsageLogFromService converts a service UsageLog to DTO for regular users.
// It excludes Account details and IP address - users should not see these.
func UsageLogFromService(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, nil, false)
if l == nil {
return nil
}
u := usageLogFromServiceUser(l)
return &u
}
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
// It includes minimal Account info (ID, Name only) and IP address.
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
if l == nil {
return nil
}
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
return &AdminUsageLog{
UsageLog: usageLogFromServiceUser(l),
AccountRateMultiplier: l.AccountRateMultiplier,
IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account),
}
}
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
@@ -414,7 +458,27 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
if sub == nil {
return nil
}
return &UserSubscription{
out := userSubscriptionFromServiceBase(sub)
return &out
}
// UserSubscriptionFromServiceAdmin converts a service UserSubscription to DTO for admin users.
// It includes assignment metadata and notes.
func UserSubscriptionFromServiceAdmin(sub *service.UserSubscription) *AdminUserSubscription {
if sub == nil {
return nil
}
return &AdminUserSubscription{
UserSubscription: userSubscriptionFromServiceBase(sub),
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}
func userSubscriptionFromServiceBase(sub *service.UserSubscription) UserSubscription {
return UserSubscription{
ID: sub.ID,
UserID: sub.UserID,
GroupID: sub.GroupID,
@@ -427,14 +491,10 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
DailyUsageUSD: sub.DailyUsageUSD,
WeeklyUsageUSD: sub.WeeklyUsageUSD,
MonthlyUsageUSD: sub.MonthlyUsageUSD,
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
CreatedAt: sub.CreatedAt,
UpdatedAt: sub.UpdatedAt,
User: UserFromServiceShallow(sub.User),
Group: GroupFromServiceShallow(sub.Group),
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}
@@ -442,9 +502,9 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
if r == nil {
return nil
}
subs := make([]UserSubscription, 0, len(r.Subscriptions))
subs := make([]AdminUserSubscription, 0, len(r.Subscriptions))
for i := range r.Subscriptions {
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
}
return &BulkAssignResult{
SuccessCount: r.SuccessCount,

View File

@@ -2,8 +2,12 @@ package dto
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@@ -22,13 +26,16 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
@@ -52,19 +59,25 @@ type SystemSettings struct {
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
}
// StreamTimeoutSettings 流超时处理配置 DTO

View File

@@ -6,7 +6,6 @@ type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Notes string `json:"notes"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
@@ -19,6 +18,14 @@ type User struct {
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
// AdminUser 是管理员接口使用的 user DTO包含敏感/内部字段)。
// 注意:普通用户接口不得返回 notes 等管理员备注信息。
type AdminUser struct {
User
Notes string `json:"notes"`
}
type APIKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
@@ -58,13 +65,19 @@ type Group struct {
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AdminGroup 是管理员接口使用的 group DTO包含敏感/内部字段)。
// 注意:普通用户接口不得返回 model_routing/account_count/account_groups 等内部信息。
type AdminGroup struct {
Group
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
}
@@ -180,7 +193,6 @@ type RedeemCode struct {
Status string `json:"status"`
UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
GroupID *int64 `json:"group_id"`
@@ -190,6 +202,15 @@ type RedeemCode struct {
Group *Group `json:"group,omitempty"`
}
// AdminRedeemCode 是管理员接口使用的 redeem code DTO包含 notes 等字段)。
// 注意:普通用户接口不得返回 notes 等内部信息。
type AdminRedeemCode struct {
RedeemCode
Notes string `json:"notes"`
}
// UsageLog 是普通用户接口使用的 usage log DTO不包含管理员字段
type UsageLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
@@ -209,14 +230,13 @@ type UsageLog struct {
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"`
@@ -230,18 +250,28 @@ type UsageLog struct {
// User-Agent
UserAgent *string `json:"user_agent"`
// IP 地址(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
APIKey *APIKey `json:"api_key,omitempty"`
Account *AccountSummary `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`
}
// AdminUsageLog 是管理员接口使用的 usage log DTO包含管理员字段
type AdminUsageLog struct {
UsageLog
// AccountRateMultiplier 账号计费倍率快照nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
// IPAddress 用户请求 IP仅管理员可见
IPAddress *string `json:"ip_address,omitempty"`
// Account 最小账号信息(避免泄露敏感字段)
Account *AccountSummary `json:"account,omitempty"`
}
type UsageCleanupFilters struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
@@ -300,23 +330,30 @@ type UserSubscription struct {
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
// AdminUserSubscription 是管理员接口使用的订阅 DTO包含分配信息/备注等字段)。
// 注意:普通用户接口不得返回 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段。
type AdminUserSubscription struct {
UserSubscription
AssignedBy *int64 `json:"assigned_by"`
AssignedAt time.Time `json:"assigned_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
}
type BulkAssignResult struct {
SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"`
Subscriptions []UserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"`
Subscriptions []AdminUserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
}
// PromoCode 注册优惠码

View File

@@ -209,17 +209,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
// 检查请求拦截(预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
@@ -344,17 +347,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检查预热请求拦截(在账号选择后、转发前检查
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
// 检查请求拦截(预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockInterceptStream(c, reqModel, interceptType)
} else {
sendMockInterceptResponse(c, reqModel, interceptType)
}
return
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
@@ -765,17 +771,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
}
// isWarmupRequest 检测是否为预热请求标题生成、Warmup等
func isWarmupRequest(body []byte) bool {
// 快速检查如果body不包含关键字直接返回false
// InterceptType 表示请求拦截类型
type InterceptType int
const (
InterceptTypeNone InterceptType = iota
InterceptTypeWarmup // 预热请求(返回 "New Conversation"
InterceptTypeSuggestionMode // SUGGESTION MODE返回空字符串
)
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
func detectInterceptType(body []byte) InterceptType {
// 快速检查:如果不包含任何关键字,直接返回
bodyStr := string(body)
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
return false
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
if !hasSuggestionMode && !hasWarmupKeyword {
return InterceptTypeNone
}
// 解析完整请求
// 解析请求(只解析一次)
var req struct {
Messages []struct {
Role string `json:"role"`
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
@@ -786,43 +805,71 @@ func isWarmupRequest(body []byte) bool {
} `json:"system"`
}
if err := json.Unmarshal(body, &req); err != nil {
return false
return InterceptTypeNone
}
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return true
// 检查 SUGGESTION MODE最后一条 user 消息)
if hasSuggestionMode && len(req.Messages) > 0 {
lastMsg := req.Messages[len(req.Messages)-1]
if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
lastMsg.Content[0].Type == "text" &&
strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
return InterceptTypeSuggestionMode
}
}
// 检查 Warmup 请求
if hasWarmupKeyword {
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return InterceptTypeWarmup
}
}
}
}
// 检查 system 中的标题提取模式
for _, sys := range req.System {
if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return InterceptTypeWarmup
}
}
}
// 检查 system 中的标题提取模式
for _, system := range req.System {
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return true
}
}
return false
return InterceptTypeNone
}
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
func sendMockWarmupStream(c *gin.Context, model string) {
// sendMockInterceptStream 发送流式 mock 响应(用于请求拦截)
func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
// 根据拦截类型决定响应内容
var msgID string
var outputTokens int
var textDeltas []string
switch interceptType {
case InterceptTypeSuggestionMode:
msgID = "msg_mock_suggestion"
outputTokens = 1
textDeltas = []string{""} // 空内容
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
outputTokens = 2
textDeltas = []string{"New", " Conversation"}
}
// Build message_start event with proper JSON marshaling
messageStart := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": "msg_mock_warmup",
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
@@ -837,16 +884,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
}
messageStartJSON, _ := json.Marshal(messageStart)
// Build events
events := []string{
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
}
// Add text deltas
for _, text := range textDeltas {
delta := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]string{
"type": "text_delta",
"text": text,
},
}
deltaJSON, _ := json.Marshal(delta)
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
}
// Add final events
messageDelta := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": "end_turn",
"stop_sequence": nil,
},
"usage": map[string]int{
"input_tokens": 10,
"output_tokens": outputTokens,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
events = append(events,
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
`event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
)
for _, event := range events {
_, _ = c.Writer.WriteString(event + "\n\n")
c.Writer.Flush()
@@ -854,18 +931,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
}
}
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) {
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
var msgID, text string
var outputTokens int
switch interceptType {
case InterceptTypeSuggestionMode:
msgID = "msg_mock_suggestion"
text = ""
outputTokens = 1
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
text = "New Conversation"
outputTokens = 2
}
c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup",
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
"content": []gin.H{{"type": "text", "text": text}},
"stop_reason": "end_turn",
"usage": gin.H{
"input_tokens": 10,
"output_tokens": 2,
"output_tokens": outputTokens,
},
})
}

View File

@@ -0,0 +1,122 @@
//go:build unit
package handler
import (
"crypto/sha256"
"encoding/hex"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractGeminiCLISessionHash(t *testing.T) {
tests := []struct {
name string
body string
privilegedUserID string
wantEmpty bool
wantHash string
}{
{
name: "with privileged-user-id and tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: false,
wantHash: func() string {
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}(),
},
{
name: "without privileged-user-id but with tmp dir",
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
privilegedUserID: "",
wantEmpty: false,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "without tmp dir",
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
{
name: "empty body",
body: "",
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 创建测试上下文
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/test", nil)
if tt.privilegedUserID != "" {
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
}
// 调用函数
result := extractGeminiCLISessionHash(c, []byte(tt.body))
// 验证结果
if tt.wantEmpty {
require.Empty(t, result, "expected empty session hash")
} else {
require.NotEmpty(t, result, "expected non-empty session hash")
require.Equal(t, tt.wantHash, result, "session hash mismatch")
}
})
}
}
func TestGeminiCLITmpDirRegex(t *testing.T) {
tests := []struct {
name string
input string
wantMatch bool
wantHash string
}{
{
name: "valid tmp dir path",
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "valid tmp dir path in text",
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
wantMatch: true,
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
},
{
name: "invalid hash length",
input: "/Users/ianshaw/.gemini/tmp/abc123",
wantMatch: false,
},
{
name: "no tmp dir",
input: "Hello world",
wantMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
if tt.wantMatch {
require.NotNil(t, match, "expected regex to match")
require.Len(t, match, 2, "expected 2 capture groups")
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
} else {
require.Nil(t, match, "expected regex not to match")
}
})
}
}

View File

@@ -1,11 +1,15 @@
package handler
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"io"
"log"
"net/http"
"regexp"
"strings"
"time"
@@ -19,6 +23,17 @@ import (
"github.com/gin-gonic/gin"
)
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
return true
}
return geminiCLITmpDirRegex.Match(body)
}
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 优先使用 Gemini CLI 的会话标识privileged-user-id + tmp 目录哈希)
sessionHash := extractGeminiCLISessionHash(c, body)
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
// 查询粘性会话绑定的账号 ID用于检测账号切换
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后CLI 继续携带旧签名。
// 为避免第一次转发就 400这里做一次确定性清理让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
} else if sessionBoundAccountID == 0 {
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
sessionBoundAccountID = account.ID
}
// 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
}
return false
}
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
//
// 会话标识生成策略:
// 1. 从请求体中提取 tmp 目录哈希64位十六进制
// 2. 从 header 中提取 privileged-user-idUUID
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
//
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
//
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
// 1. 从请求体中提取 tmp 目录哈希
match := geminiCLITmpDirRegex.FindSubmatch(body)
if len(match) < 2 {
return "" // 没有找到 tmp 目录,不使用粘性会话
}
tmpDirHash := string(match[1])
// 2. 提取 privileged-user-id
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
// 3. 组合生成最终的 session hash
if privilegedUserID != "" {
// 组合两个标识符privileged-user-id + tmp 目录哈希
combined := privilegedUserID + ":" + tmpDirHash
hash := sha256.Sum256([]byte(combined))
return hex.EncodeToString(hash[:])
}
// 如果没有 privileged-user-id直接使用 tmp 目录哈希
return tmpDirHash
}

View File

@@ -37,6 +37,7 @@ type Handlers struct {
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler
Totp *TotpHandler
}
// BuildInfo contains build-time information

View File

@@ -32,18 +32,24 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
}
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
})
}

View File

@@ -0,0 +1,181 @@
package handler
import (
"github.com/gin-gonic/gin"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TotpHandler handles TOTP-related requests
type TotpHandler struct {
totpService *service.TotpService
}
// NewTotpHandler creates a new TotpHandler
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
return &TotpHandler{
totpService: totpService,
}
}
// TotpStatusResponse represents the TOTP status response
type TotpStatusResponse struct {
Enabled bool `json:"enabled"`
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
FeatureEnabled bool `json:"feature_enabled"`
}
// GetStatus returns the TOTP status for the current user
// GET /api/v1/user/totp/status
func (h *TotpHandler) GetStatus(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := TotpStatusResponse{
Enabled: status.Enabled,
FeatureEnabled: status.FeatureEnabled,
}
if status.EnabledAt != nil {
ts := status.EnabledAt.Unix()
resp.EnabledAt = &ts
}
response.Success(c, resp)
}
// TotpSetupRequest represents the request to initiate TOTP setup
type TotpSetupRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// TotpSetupResponse represents the TOTP setup response
type TotpSetupResponse struct {
Secret string `json:"secret"`
QRCodeURL string `json:"qr_code_url"`
SetupToken string `json:"setup_token"`
Countdown int `json:"countdown"`
}
// InitiateSetup starts the TOTP setup process
// POST /api/v1/user/totp/setup
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpSetupRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body (optional params)
req = TotpSetupRequest{}
}
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, TotpSetupResponse{
Secret: result.Secret,
QRCodeURL: result.QRCodeURL,
SetupToken: result.SetupToken,
Countdown: result.Countdown,
})
}
// TotpEnableRequest represents the request to enable TOTP
type TotpEnableRequest struct {
TotpCode string `json:"totp_code" binding:"required,len=6"`
SetupToken string `json:"setup_token" binding:"required"`
}
// Enable completes the TOTP setup
// POST /api/v1/user/totp/enable
func (h *TotpHandler) Enable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpEnableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// TotpDisableRequest represents the request to disable TOTP
type TotpDisableRequest struct {
EmailCode string `json:"email_code"`
Password string `json:"password"`
}
// Disable disables TOTP for the current user
// POST /api/v1/user/totp/disable
func (h *TotpHandler) Disable(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req TotpDisableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}
// GetVerificationMethod returns the verification method for TOTP operations
// GET /api/v1/user/totp/verification-method
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
method := h.totpService.GetVerificationMethod(c.Request.Context())
response.Success(c, method)
}
// SendVerifyCode sends an email verification code for TOTP operations
// POST /api/v1/user/totp/send-code
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"success": true})
}

View File

@@ -47,9 +47,6 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return
}
// 清空notes字段普通用户不应看到备注
userData.Notes = ""
response.Success(c, dto.UserFromService(userData))
}
@@ -105,8 +102,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return
}
// 清空notes字段普通用户不应看到备注
updatedUser.Notes = ""
response.Success(c, dto.UserFromService(updatedUser))
}

View File

@@ -70,6 +70,7 @@ func ProvideHandlers(
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
settingHandler *SettingHandler,
totpHandler *TotpHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
@@ -82,6 +83,7 @@ func ProvideHandlers(
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler,
Totp: totpHandler,
}
}
@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
NewTotpHandler,
ProvideSettingHandler,
// Admin handlers

View File

@@ -7,13 +7,11 @@ import (
"fmt"
"log"
"math/rand"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -369,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
Text: block.Thinking,
Thought: true,
}
// 保留原有 signatureClaude 模型需要有效的 signature
if block.Signature != "" {
// signature 处理:
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失)
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature在缺失时降级为普通文本并在上层禁用 thinking mode。
@@ -409,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
},
}
// tool_use 的 signature 处理:
// - Gemini 模型:使用 dummy signature跳过 thought_signature 校验
// - Claude 模型:透传上游返回的真实 signatureVertex/Google 需要完整签名链路)
if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
// - Claude 模型allowDummyThought=false必须是上游返回的真实 signaturedummy 视为缺失
// - Gemini 模型allowDummyThought=true优先透传真实 signature缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
@@ -594,11 +594,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
// 清理 JSON Schema
params := cleanJSONSchema(inputSchema)
// 1. 深度清理 [undefined] 值
DeepCleanUndefined(inputSchema)
// 2. 转换为符合 Gemini v1internal 的 schema
params := CleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值
if params == nil {
params = map[string]any{
"type": "OBJECT",
"type": "object", // lowercase type
"properties": map[string]any{},
}
}
@@ -631,236 +634,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
FunctionDeclarations: funcDecls,
}}
}
// cleanJSONSchema 清理 JSON Schema移除 Antigravity/Gemini 不支持的字段
// 参考 proxycast 的实现,确保 schema 符合 JSON Schema draft 2020-12
func cleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
cleaned := cleanSchemaValue(schema, "$")
result, ok := cleaned.(map[string]any)
if !ok {
return nil
}
// 确保有 type 字段(默认 OBJECT
if _, hasType := result["type"]; !hasType {
result["type"] = "OBJECT"
}
// 确保有 properties 字段(默认空对象)
if _, hasProps := result["properties"]; !hasProps {
result["properties"] = make(map[string]any)
}
// 验证 required 中的字段都存在于 properties 中
if required, ok := result["required"].([]any); ok {
if props, ok := result["properties"].(map[string]any); ok {
validRequired := make([]any, 0, len(required))
for _, r := range required {
if reqName, ok := r.(string); ok {
if _, exists := props[reqName]; exists {
validRequired = append(validRequired, r)
}
}
}
if len(validRequired) > 0 {
result["required"] = validRequired
} else {
delete(result, "required")
}
}
}
return result
}
var schemaValidationKeys = map[string]bool{
"minLength": true,
"maxLength": true,
"pattern": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
"uniqueItems": true,
"minItems": true,
"maxItems": true,
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
}
var warnedSchemaKeys sync.Map
func schemaCleaningWarningsEnabled() bool {
// 可通过环境变量强制开关方便排查SUB2API_SCHEMA_CLEAN_WARN=true/false
if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
}
}
// 默认:非 release 模式下输出debug/test
return gin.Mode() != gin.ReleaseMode
}
func warnSchemaKeyRemovedOnce(key, path string) {
if !schemaCleaningWarningsEnabled() {
return
}
if !schemaValidationKeys[key] {
return
}
if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
return
}
log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{
// 元 schema 字段
"$schema": true,
"$id": true,
"$ref": true,
// 字符串验证Gemini 不支持)
"minLength": true,
"maxLength": true,
"pattern": true,
// 数字验证Claude API 通过 Vertex AI 不支持这些字段)
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
// 数组验证Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems": true,
"minItems": true,
"maxItems": true,
// 组合 schemaGemini 不支持)
"oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// Claude 特有字段
"strict": true,
}
// cleanSchemaValue 递归清理 schema 值
func cleanSchemaValue(value any, path string) any {
switch v := value.(type) {
case map[string]any:
result := make(map[string]any)
for k, val := range v {
// 跳过不支持的字段
if excludedSchemaKeys[k] {
warnSchemaKeyRemovedOnce(k, path)
continue
}
// 特殊处理 type 字段
if k == "type" {
result[k] = cleanTypeValue(val)
continue
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalPropertiesClaude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
} else {
// 如果是 schema 对象,转换为 false更安全的默认值
result[k] = false
}
continue
}
// 递归清理所有值
result[k] = cleanSchemaValue(val, path+"."+k)
}
return result
case []any:
// 递归处理数组中的每个元素
cleaned := make([]any, 0, len(v))
for i, item := range v {
cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
}
return cleaned
default:
return value
}
}
// cleanTypeValue 处理 type 字段,转换为大写
func cleanTypeValue(value any) any {
switch v := value.(type) {
case string:
return strings.ToUpper(v)
case []any:
// 联合类型 ["string", "null"] -> 取第一个非 null 类型
for _, t := range v {
if ts, ok := t.(string); ok && ts != "null" {
return strings.ToUpper(ts)
}
}
// 如果只有 null返回 STRING
return "STRING"
default:
return value
}
}

View File

@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]`
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil {
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != "sig_tool_abc" {
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
}
})
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
contentNoSig := `[
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
]`
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
}

View File

@@ -3,6 +3,7 @@ package antigravity
import (
"encoding/json"
"fmt"
"log"
"strings"
)
@@ -19,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
} else if len(v1Resp.Response.Candidates) == 0 {
// 第一次解析成功但 candidates 为空,说明是直接的 GeminiResponse 格式
var directResp GeminiResponse
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
return nil, nil, fmt.Errorf("parse gemini response as direct: %w", err2)
}
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
}
// 使用处理器转换
@@ -173,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
p.trailingSignature = ""
}
p.textBuilder += part.Text
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
// 非空 text 带签名 - 特殊处理:先输出 text再输出空 thinking 块
if signature != "" {
p.flushText()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "text",
Text: part.Text,
})
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: signature,
})
} else {
// 普通 text (无签名) - 累积到 builder
p.textBuilder += part.Text
}
}
}
@@ -242,6 +256,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
var finishReason string
if len(geminiResp.Candidates) > 0 {
finishReason = geminiResp.Candidates[0].FinishReason
if finishReason == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in response for model %s", originalModel)
if geminiResp.Candidates[0].Content != nil {
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
stopReason := "end_turn"

View File

@@ -0,0 +1,519 @@
package antigravity
import (
"fmt"
"strings"
)
// CleanJSONSchema 清理 JSON Schema移除 Antigravity/Gemini 不支持的字段
// 参考 Antigravity-Manager/src-tauri/src/proxy/common/json_schema.rs 实现
// 确保 schema 符合 JSON Schema draft 2020-12 且适配 Gemini v1internal
func CleanJSONSchema(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
// 0. 预处理:展开 $ref (Schema Flattening)
// (Go map 是引用的,直接修改 schema)
flattenRefs(schema, extractDefs(schema))
// 递归清理
cleaned := cleanJSONSchemaRecursive(schema)
result, ok := cleaned.(map[string]any)
if !ok {
return nil
}
return result
}
// extractDefs 提取并移除定义的 helper
func extractDefs(schema map[string]any) map[string]any {
defs := make(map[string]any)
if d, ok := schema["$defs"].(map[string]any); ok {
for k, v := range d {
defs[k] = v
}
delete(schema, "$defs")
}
if d, ok := schema["definitions"].(map[string]any); ok {
for k, v := range d {
defs[k] = v
}
delete(schema, "definitions")
}
return defs
}
// flattenRefs 递归展开 $ref
func flattenRefs(schema map[string]any, defs map[string]any) {
if len(defs) == 0 {
return // 无需展开
}
// 检查并替换 $ref
if ref, ok := schema["$ref"].(string); ok {
delete(schema, "$ref")
// 解析引用名 (例如 #/$defs/MyType -> MyType)
parts := strings.Split(ref, "/")
refName := parts[len(parts)-1]
if defSchema, exists := defs[refName]; exists {
if defMap, ok := defSchema.(map[string]any); ok {
// 合并定义内容 (不覆盖现有 key)
for k, v := range defMap {
if _, has := schema[k]; !has {
schema[k] = deepCopy(v) // 需深拷贝避免共享引用
}
}
// 递归处理刚刚合并进来的内容
flattenRefs(schema, defs)
}
}
}
// 遍历子节点
for _, v := range schema {
if subMap, ok := v.(map[string]any); ok {
flattenRefs(subMap, defs)
} else if subArr, ok := v.([]any); ok {
for _, item := range subArr {
if itemMap, ok := item.(map[string]any); ok {
flattenRefs(itemMap, defs)
}
}
}
}
}
// deepCopy 深拷贝 (简单实现,仅针对 JSON 类型)
func deepCopy(src any) any {
if src == nil {
return nil
}
switch v := src.(type) {
case map[string]any:
dst := make(map[string]any)
for k, val := range v {
dst[k] = deepCopy(val)
}
return dst
case []any:
dst := make([]any, len(v))
for i, val := range v {
dst[i] = deepCopy(val)
}
return dst
default:
return src
}
}
// cleanJSONSchemaRecursive 递归核心清理逻辑
// 返回处理后的值 (通常是 input map但可能修改内部结构)
func cleanJSONSchemaRecursive(value any) any {
schemaMap, ok := value.(map[string]any)
if !ok {
return value
}
// 0. [NEW] 合并 allOf
mergeAllOf(schemaMap)
// 1. [CRITICAL] 深度递归处理子项
if props, ok := schemaMap["properties"].(map[string]any); ok {
for _, v := range props {
cleanJSONSchemaRecursive(v)
}
// Go 中不需要像 Rust 那样显式处理 nullable_keys remove required
// 因为我们在子项处理中会正确设置 type 和 description
} else if items, ok := schemaMap["items"]; ok {
// [FIX] Gemini 期望 "items" 是单个 Schema 对象(列表验证),而不是数组(元组验证)。
if itemsArr, ok := items.([]any); ok {
// 策略:将元组 [A, B] 视为 A、B 中的最佳匹配项。
best := extractBestSchemaFromUnion(itemsArr)
if best == nil {
// 回退到通用字符串
best = map[string]any{"type": "string"}
}
// 用处理后的对象替换原有数组
cleanedBest := cleanJSONSchemaRecursive(best)
schemaMap["items"] = cleanedBest
} else {
cleanJSONSchemaRecursive(items)
}
} else {
// 遍历所有值递归
for _, v := range schemaMap {
if _, isMap := v.(map[string]any); isMap {
cleanJSONSchemaRecursive(v)
} else if arr, isArr := v.([]any); isArr {
for _, item := range arr {
cleanJSONSchemaRecursive(item)
}
}
}
}
// 2. [FIX] 处理 anyOf/oneOf 联合类型: 合并属性而非直接删除
var unionArray []any
typeStr, _ := schemaMap["type"].(string)
if typeStr == "" || typeStr == "object" {
if anyOf, ok := schemaMap["anyOf"].([]any); ok {
unionArray = anyOf
} else if oneOf, ok := schemaMap["oneOf"].([]any); ok {
unionArray = oneOf
}
}
if len(unionArray) > 0 {
if bestBranch := extractBestSchemaFromUnion(unionArray); bestBranch != nil {
if bestMap, ok := bestBranch.(map[string]any); ok {
// 合并分支内容
for k, v := range bestMap {
if k == "properties" {
targetProps, _ := schemaMap["properties"].(map[string]any)
if targetProps == nil {
targetProps = make(map[string]any)
schemaMap["properties"] = targetProps
}
if sourceProps, ok := v.(map[string]any); ok {
for pk, pv := range sourceProps {
if _, exists := targetProps[pk]; !exists {
targetProps[pk] = deepCopy(pv)
}
}
}
} else if k == "required" {
targetReq, _ := schemaMap["required"].([]any)
if sourceReq, ok := v.([]any); ok {
for _, rv := range sourceReq {
// 简单的去重添加
exists := false
for _, tr := range targetReq {
if tr == rv {
exists = true
break
}
}
if !exists {
targetReq = append(targetReq, rv)
}
}
schemaMap["required"] = targetReq
}
} else if _, exists := schemaMap[k]; !exists {
schemaMap[k] = deepCopy(v)
}
}
}
}
}
// 3. [SAFETY] 检查当前对象是否为 JSON Schema 节点
looksLikeSchema := hasKey(schemaMap, "type") ||
hasKey(schemaMap, "properties") ||
hasKey(schemaMap, "items") ||
hasKey(schemaMap, "enum") ||
hasKey(schemaMap, "anyOf") ||
hasKey(schemaMap, "oneOf") ||
hasKey(schemaMap, "allOf")
if looksLikeSchema {
// 4. [ROBUST] 约束迁移
migrateConstraints(schemaMap)
// 5. [CRITICAL] 白名单过滤
allowedFields := map[string]bool{
"type": true,
"description": true,
"properties": true,
"required": true,
"items": true,
"enum": true,
"title": true,
}
for k := range schemaMap {
if !allowedFields[k] {
delete(schemaMap, k)
}
}
// 6. [SAFETY] 处理空 Object
if t, _ := schemaMap["type"].(string); t == "object" {
hasProps := false
if props, ok := schemaMap["properties"].(map[string]any); ok && len(props) > 0 {
hasProps = true
}
if !hasProps {
schemaMap["properties"] = map[string]any{
"reason": map[string]any{
"type": "string",
"description": "Reason for calling this tool",
},
}
schemaMap["required"] = []any{"reason"}
}
}
// 7. [SAFETY] Required 字段对齐
if props, ok := schemaMap["properties"].(map[string]any); ok {
if req, ok := schemaMap["required"].([]any); ok {
var validReq []any
for _, r := range req {
if rStr, ok := r.(string); ok {
if _, exists := props[rStr]; exists {
validReq = append(validReq, r)
}
}
}
if len(validReq) > 0 {
schemaMap["required"] = validReq
} else {
delete(schemaMap, "required")
}
}
}
// 8. 处理 type 字段 (Lowercase + Nullable 提取)
isEffectivelyNullable := false
if typeVal, exists := schemaMap["type"]; exists {
var selectedType string
switch v := typeVal.(type) {
case string:
lower := strings.ToLower(v)
if lower == "null" {
isEffectivelyNullable = true
selectedType = "string" // fallback
} else {
selectedType = lower
}
case []any:
// ["string", "null"]
for _, t := range v {
if ts, ok := t.(string); ok {
lower := strings.ToLower(ts)
if lower == "null" {
isEffectivelyNullable = true
} else if selectedType == "" {
selectedType = lower
}
}
}
if selectedType == "" {
selectedType = "string"
}
}
schemaMap["type"] = selectedType
} else {
// 默认 object 如果有 properties (虽然上面白名单过滤可能删了 type 如果它不在... 但 type 必在 allowlist)
// 如果没有 type但有 properties补一个
if hasKey(schemaMap, "properties") {
schemaMap["type"] = "object"
} else {
// 默认为 string ? or object? Gemini 通常需要明确 type
schemaMap["type"] = "object"
}
}
if isEffectivelyNullable {
desc, _ := schemaMap["description"].(string)
if !strings.Contains(desc, "nullable") {
if desc != "" {
desc += " "
}
desc += "(nullable)"
schemaMap["description"] = desc
}
}
// 9. Enum 值强制转字符串
if enumVals, ok := schemaMap["enum"].([]any); ok {
hasNonString := false
for i, val := range enumVals {
if _, isStr := val.(string); !isStr {
hasNonString = true
if val == nil {
enumVals[i] = "null"
} else {
enumVals[i] = fmt.Sprintf("%v", val)
}
}
}
// If we mandated string values, we must ensure type is string
if hasNonString {
schemaMap["type"] = "string"
}
}
}
return schemaMap
}
func hasKey(m map[string]any, k string) bool {
_, ok := m[k]
return ok
}
func migrateConstraints(m map[string]any) {
constraints := []struct {
key string
label string
}{
{"minLength", "minLen"},
{"maxLength", "maxLen"},
{"pattern", "pattern"},
{"minimum", "min"},
{"maximum", "max"},
{"multipleOf", "multipleOf"},
{"exclusiveMinimum", "exclMin"},
{"exclusiveMaximum", "exclMax"},
{"minItems", "minItems"},
{"maxItems", "maxItems"},
{"propertyNames", "propertyNames"},
{"format", "format"},
}
var hints []string
for _, c := range constraints {
if val, ok := m[c.key]; ok && val != nil {
hints = append(hints, fmt.Sprintf("%s: %v", c.label, val))
}
}
if len(hints) > 0 {
suffix := fmt.Sprintf(" [Constraint: %s]", strings.Join(hints, ", "))
desc, _ := m["description"].(string)
if !strings.Contains(desc, suffix) {
m["description"] = desc + suffix
}
}
}
// mergeAllOf 合并 allOf
func mergeAllOf(m map[string]any) {
allOf, ok := m["allOf"].([]any)
if !ok {
return
}
delete(m, "allOf")
mergedProps := make(map[string]any)
mergedReq := make(map[string]bool)
otherFields := make(map[string]any)
for _, sub := range allOf {
if subMap, ok := sub.(map[string]any); ok {
// Props
if props, ok := subMap["properties"].(map[string]any); ok {
for k, v := range props {
mergedProps[k] = v
}
}
// Required
if reqs, ok := subMap["required"].([]any); ok {
for _, r := range reqs {
if s, ok := r.(string); ok {
mergedReq[s] = true
}
}
}
// Others
for k, v := range subMap {
if k != "properties" && k != "required" && k != "allOf" {
if _, exists := otherFields[k]; !exists {
otherFields[k] = v
}
}
}
}
}
// Apply
for k, v := range otherFields {
if _, exists := m[k]; !exists {
m[k] = v
}
}
if len(mergedProps) > 0 {
existProps, _ := m["properties"].(map[string]any)
if existProps == nil {
existProps = make(map[string]any)
m["properties"] = existProps
}
for k, v := range mergedProps {
if _, exists := existProps[k]; !exists {
existProps[k] = v
}
}
}
if len(mergedReq) > 0 {
existReq, _ := m["required"].([]any)
var validReqs []any
for _, r := range existReq {
if s, ok := r.(string); ok {
validReqs = append(validReqs, s)
delete(mergedReq, s) // already exists
}
}
// append new
for r := range mergedReq {
validReqs = append(validReqs, r)
}
m["required"] = validReqs
}
}
// extractBestSchemaFromUnion 从 anyOf/oneOf 中选取最佳分支
func extractBestSchemaFromUnion(unionArray []any) any {
var bestOption any
bestScore := -1
for _, item := range unionArray {
score := scoreSchemaOption(item)
if score > bestScore {
bestScore = score
bestOption = item
}
}
return bestOption
}
func scoreSchemaOption(val any) int {
m, ok := val.(map[string]any)
if !ok {
return 0
}
typeStr, _ := m["type"].(string)
if hasKey(m, "properties") || typeStr == "object" {
return 3
}
if hasKey(m, "items") || typeStr == "array" {
return 2
}
if typeStr != "" && typeStr != "null" {
return 1
}
return 0
}
// DeepCleanUndefined 深度清理值为 "[undefined]" 的字段
func DeepCleanUndefined(value any) {
if value == nil {
return
}
switch v := value.(type) {
case map[string]any:
for k, val := range v {
if s, ok := val.(string); ok && s == "[undefined]" {
delete(v, k)
continue
}
DeepCleanUndefined(val)
}
case []any:
for _, val := range v {
DeepCleanUndefined(val)
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"log"
"strings"
)
@@ -102,6 +103,14 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
// 检查是否结束
if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason
if finishReason == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in stream for model %s", p.originalModel)
if geminiResp.Candidates[0].Content != nil {
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
if finishReason != "" {
_, _ = result.Write(p.emitFinish(finishReason))
}

View File

@@ -13,20 +13,26 @@ import (
"time"
)
// Claude OAuth Constants (from CRS project)
// Claude OAuth Constants
const (
// OAuth Client ID for Claude
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
// OAuth endpoints
AuthorizeURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
TokenURL = "https://platform.claude.com/v1/oauth/token"
RedirectURI = "https://platform.claude.com/oauth/code/callback"
// Scopes
ScopeProfile = "user:profile"
// Scopes - Browser URL (includes org:create_api_key for user authorization)
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Internal API call (org:create_api_key not supported in API)
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Setup token (inference only)
ScopeInference = "user:inference"
// Code Verifier character set (RFC 7636 compliant)
codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
// Session TTL
SessionTTL = 30 * time.Minute
)
@@ -53,7 +59,6 @@ func NewSessionStore() *SessionStore {
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
@@ -78,7 +83,6 @@ func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
@@ -122,13 +126,13 @@ func GenerateRandomBytes(n int) ([]byte, error) {
return b, nil
}
// GenerateState generates a random state string for OAuth
// GenerateState generates a random state string for OAuth (base64url encoded)
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
return base64URLEncode(bytes), nil
}
// GenerateSessionID generates a unique session ID
@@ -140,13 +144,30 @@ func GenerateSessionID() (string, error) {
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
// GenerateCodeVerifier generates a PKCE code verifier using character set method
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
const targetLen = 32
charsetLen := len(codeVerifierCharset)
limit := 256 - (256 % charsetLen)
result := make([]byte, 0, targetLen)
randBuf := make([]byte, targetLen*2)
for len(result) < targetLen {
if _, err := rand.Read(randBuf); err != nil {
return "", err
}
for _, b := range randBuf {
if int(b) < limit {
result = append(result, codeVerifierCharset[int(b)%charsetLen])
if len(result) >= targetLen {
break
}
}
}
}
return base64URLEncode(bytes), nil
return base64URLEncode(result), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
@@ -158,42 +179,31 @@ func GenerateCodeChallenge(verifier string) string {
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OAuth authorization URL
// BuildAuthorizationURL builds the OAuth authorization URL with correct parameter order
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("scope", scope)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
encodedRedirectURI := url.QueryEscape(RedirectURI)
encodedScope := strings.ReplaceAll(url.QueryEscape(scope), "%20", "+")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
return fmt.Sprintf("%s?code=true&client_id=%s&response_type=code&redirect_uri=%s&scope=%s&code_challenge=%s&code_challenge_method=S256&state=%s",
AuthorizeURL,
ClientID,
encodedRedirectURI,
encodedScope,
codeChallenge,
state,
)
}
// TokenResponse represents the token response from OAuth provider
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
// Organization and Account info from OAuth response
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
Organization *OrgInfo `json:"organization,omitempty"`
Account *AccountInfo `json:"account,omitempty"`
}
@@ -205,33 +215,6 @@ type OrgInfo struct {
// AccountInfo represents account info from OAuth response
type AccountInfo struct {
UUID string `json:"uuid"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
}
// BuildTokenRequest creates a token exchange request
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: RedirectURI,
CodeVerifier: codeVerifier,
State: state,
}
}
// BuildRefreshTokenRequest creates a refresh token request
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
}
UUID string `json:"uuid"`
EmailAddress string `json:"email_address"`
}

View File

@@ -2,6 +2,7 @@
package response
import (
"log"
"math"
"net/http"
@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool {
}
statusCode, status := infraerrors.ToHTTP(err)
// Log internal errors with full details for debugging
if statusCode >= 500 && c.Request != nil {
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
}
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}

View File

@@ -0,0 +1,278 @@
//go:build integration
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests to external services and should be run manually.
//
// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
)
// skipIfExternalServiceUnavailable checks if the external service is available.
// If not, it skips the test instead of failing.
func skipIfExternalServiceUnavailable(t *testing.T, err error) {
t.Helper()
if err != nil {
// Check for common network/TLS errors that indicate external service issues
errStr := err.Error()
if strings.Contains(errStr, "certificate has expired") ||
strings.Contains(errStr, "certificate is not yet valid") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "no such host") ||
strings.Contains(errStr, "network is unreachable") ||
strings.Contains(errStr, "timeout") {
t.Skipf("skipping test: external service unavailable: %v", err)
}
t.Fatalf("failed to get fingerprint: %v", err)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
skipIfExternalServiceUnavailable(t, err)
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
skipIfExternalServiceUnavailable(t, err)
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}

View File

@@ -1,21 +1,16 @@
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Integration tests for verifying TLS fingerprint correctness.
// These tests make actual network requests and should be run manually.
// Unit tests for TLS fingerprint dialer.
// Integration tests that require external network are in dialer_integration_test.go
// and require the 'integration' build tag.
//
// Run with: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
)
// FingerprintResponse represents the response from tls.peet.ws/api/all.
@@ -36,148 +31,6 @@ type TLSInfo struct {
SessionID string `json:"session_id"`
}
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
func TestDialerBasicConnection(t *testing.T) {
if testing.Short() {
t.Skip("skipping network test in short mode")
}
// Create a dialer with default profile
profile := &Profile{
Name: "Test Profile",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
// Create HTTP client with custom TLS dialer
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Make a request to a known HTTPS endpoint
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints.
func TestDialerWithProfile(t *testing.T) {
// Create two dialers with different profiles

View File

@@ -39,9 +39,15 @@ import (
// 设计说明:
// - client: Ent 客户端,用于类型安全的 ORM 操作
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
// - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照
type accountRepository struct {
client *dbent.Client // Ent ORM 客户端
sql sqlExecutor // 原生 SQL 执行接口
// schedulerCache 用于在账号状态变更时主动同步快照到缓存,
// 确保粘性会话能及时感知账号不可用状态。
// Used to proactively sync account snapshot to cache when status changes,
// ensuring sticky sessions can promptly detect unavailable accounts.
schedulerCache service.SchedulerCache
}
type tempUnschedSnapshot struct {
@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct {
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
return newAccountRepositoryWithSQL(client, sqlDB)
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache)
}
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
// 这种设计便于单元测试时注入 mock 对象。
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
return &accountRepository{client: client, sql: sqlq}
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository {
return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache}
}
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
r.syncSchedulerAccountSnapshot(ctx, account.ID)
}
return nil
}
@@ -540,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
// syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。
// 当账号被设置为错误、禁用、不可调度或临时不可调度时调用,
// 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。
//
// syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache
// when account status changes. Called when account is set to error, disabled,
// unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session
// logic can promptly detect the latest account state and avoid using unavailable accounts.
func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
if r == nil || r.schedulerCache == nil || accountID <= 0 {
return
}
account, err := r.GetByID(ctx, accountID)
if err != nil {
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
return
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
}
}
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -873,6 +905,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
@@ -992,6 +1025,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
}
if !schedulable {
r.syncSchedulerAccountSnapshot(ctx, id)
}
return nil
}
@@ -1146,6 +1182,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
}
shouldSync := false
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
shouldSync = true
}
if updates.Schedulable != nil && !*updates.Schedulable {
shouldSync = true
}
if shouldSync {
for _, id := range ids {
r.syncSchedulerAccountSnapshot(ctx, id)
}
}
}
return rows, nil
}

View File

@@ -21,11 +21,56 @@ type AccountRepoSuite struct {
repo *accountRepository
}
type schedulerCacheRecorder struct {
setAccounts []*service.Account
}
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
return nil, false, nil
}
func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
return nil
}
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
return nil, nil
}
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
s.setAccounts = append(s.setAccounts, account)
return nil
}
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
return nil
}
func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
return true, nil
}
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) {
return 0, nil
}
func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error {
return nil
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = newAccountRepositoryWithSQL(s.client, tx)
s.repo = newAccountRepositoryWithSQL(s.client, tx, nil)
}
func TestAccountRepoSuite(t *testing.T) {
@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() {
s.Require().Equal("updated", got.Name)
}
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
account.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// 每个 case 重新获取隔离资源
tx := testEntTx(s.T())
client := tx.Client()
repo := newAccountRepositoryWithSQL(client, tx)
repo := newAccountRepositoryWithSQL(client, tx, nil)
ctx := context.Background()
tt.setup(client)
@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().False(got.Schedulable)
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
}
func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() {
account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true})
account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
disabled := service.StatusDisabled
rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{
Status: &disabled,
})
s.Require().NoError(err)
s.Require().Equal(int64(2), rows)
s.Require().Len(cacheRecorder.setAccounts, 2)
ids := map[int64]struct{}{}
for _, acc := range cacheRecorder.setAccounts {
ids[acc.ID] = struct{}{}
}
s.Require().Contains(ids, account1.ID)
s.Require().Contains(ids, account2.ID)
}
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---

View File

@@ -0,0 +1,95 @@
package repository
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// AESEncryptor implements SecretEncryptor using AES-256-GCM
type AESEncryptor struct {
key []byte
}
// NewAESEncryptor creates a new AES encryptor
func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("invalid totp encryption key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
}
return &AESEncryptor{key: key}, nil
}
// Encrypt encrypts plaintext using AES-256-GCM
// Output format: base64(nonce + ciphertext + tag)
func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create gcm: %w", err)
}
// Generate a random nonce
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("generate nonce: %w", err)
}
// Encrypt the plaintext
// Seal appends the ciphertext and tag to the nonce
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
// Encode as base64
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts ciphertext using AES-256-GCM
func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
// Decode from base64
data, err := base64.StdEncoding.DecodeString(ciphertext)
if err != nil {
return "", fmt.Errorf("decode base64: %w", err)
}
block, err := aes.NewCipher(e.key)
if err != nil {
return "", fmt.Errorf("create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create gcm: %w", err)
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", fmt.Errorf("ciphertext too short")
}
// Extract nonce and ciphertext
nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
// Decrypt
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
if err != nil {
return "", fmt.Errorf("decrypt: %w", err)
}
return string(plaintext), nil
}

View File

@@ -387,17 +387,20 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}

View File

@@ -35,7 +35,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
client := s.clientFactory(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
UUID string `json:"uuid"`
Name string `json:"name"`
RavenType *string `json:"raven_type"` // nil for personal, "team" for team organization
}
targetURL := s.baseURL + "/api/organizations"
@@ -65,7 +67,23 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return "", fmt.Errorf("no organizations found")
}
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
// 如果只有一个组织,直接使用
if len(orgs) == 1 {
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil
}
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
for _, org := range orgs {
if org.RavenType != nil && *org.RavenType == "team" {
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
org.UUID, org.Name, *org.RavenType)
return org.UUID, nil
}
}
// 如果没有 team 类型的组织,使用第一个
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil
}
@@ -182,7 +200,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
resp, err := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json, text/plain, */*").
SetHeader("Content-Type", "application/json").
SetHeader("User-Agent", "axios/1.8.4").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
@@ -205,8 +225,6 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
// Anthropic OAuth API 期望 JSON 格式的请求体
reqBody := map[string]any{
"grant_type": "refresh_token",
"refresh_token": refreshToken,
@@ -217,7 +235,9 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
resp, err := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json, text/plain, */*").
SetHeader("Content-Type", "application/json").
SetHeader("User-Agent", "axios/1.8.4").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)

View File

@@ -171,7 +171,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "")
if tt.wantErr {
require.Error(s.T(), err)

View File

@@ -14,37 +14,82 @@ import (
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
// 默认 User-Agent与用户抓包的请求一致
const defaultUsageUserAgent = "claude-code/2.1.7"
type claudeUsageService struct {
usageURL string
allowPrivateHosts bool
httpUpstream service.HTTPUpstream
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
// NewClaudeUsageFetcher 创建 Claude 用量获取服务
// httpUpstream: 可选,如果提供则支持 TLS 指纹伪装
func NewClaudeUsageFetcher(httpUpstream service.HTTPUpstream) service.ClaudeUsageFetcher {
return &claudeUsageService{
usageURL: defaultClaudeUsageURL,
httpUpstream: httpUpstream,
}
}
// FetchUsage 简单版本,不支持 TLS 指纹(向后兼容)
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
return s.FetchUsageWithOptions(ctx, &service.ClaudeUsageFetchOptions{
AccessToken: accessToken,
ProxyURL: proxyURL,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
// FetchUsageWithOptions 完整版本,支持 TLS 指纹和自定义 User-Agent
func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *service.ClaudeUsageFetchOptions) (*service.ClaudeUsageResponse, error) {
if opts == nil {
return nil, fmt.Errorf("options is nil")
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
// 设置请求头(与抓包一致,但不设置 Accept-Encoding让 Go 自动处理压缩)
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+opts.AccessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
// 设置 User-Agent优先使用缓存的 Fingerprint否则使用默认值
userAgent := defaultUsageUserAgent
if opts.Fingerprint != nil && opts.Fingerprint.UserAgent != "" {
userAgent = opts.Fingerprint.UserAgent
}
req.Header.Set("User-Agent", userAgent)
var resp *http.Response
// 如果启用 TLS 指纹且有 HTTPUpstream使用 DoWithTLS
if opts.EnableTLSFingerprint && s.httpUpstream != nil {
// accountConcurrency 传 0 使用默认连接池配置usage 请求不需要特殊的并发设置
resp, err = s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, 0, true)
if err != nil {
return nil, fmt.Errorf("request with TLS fingerprint failed: %w", err)
}
} else {
// 不启用 TLS 指纹,使用普通 HTTP 客户端
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: opts.ProxyURL,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
resp, err = client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
}
defer func() { _ = resp.Body.Close() }()

View File

@@ -9,13 +9,27 @@ import (
"github.com/redis/go-redis/v9"
)
const verifyCodeKeyPrefix = "verify_code:"
const (
verifyCodeKeyPrefix = "verify_code:"
passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:"
)
// verifyCodeKey generates the Redis key for email verification code.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
}
// passwordResetKey generates the Redis key for password reset token.
func passwordResetKey(email string) string {
return passwordResetKeyPrefix + email
}
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
func passwordResetSentAtKey(email string) string {
return passwordResetSentAtKeyPrefix + email
}
type emailCache struct {
rdb *redis.Client
}
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
key := verifyCodeKey(email)
return c.rdb.Del(ctx, key).Err()
}
// Password reset token methods
func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
key := passwordResetKey(email)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data service.PasswordResetTokenData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
key := passwordResetKey(email)
val, err := json.Marshal(data)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
key := passwordResetKey(email)
return c.rdb.Del(ctx, key).Err()
}
// Password reset email cooldown methods
func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
key := passwordResetSentAtKey(email)
exists, err := c.rdb.Exists(ctx, key).Result()
return err == nil && exists > 0
}
func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
key := passwordResetSentAtKey(email)
return c.rdb.Set(ctx, key, "1", ttl).Err()
}

View File

@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err()
}
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
// 以便下次请求能够重新选择可用账号。
//
// DeleteSessionAccountID removes the sticky session binding for the given session.
// Called when the bound account becomes unavailable (e.g., error status, disabled,
// or unschedulable), allowing subsequent requests to select a new available account.
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
sessionID := "openai:s4"
accountID := int64(102)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
groupID := int64(1)

View File

@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
}
func TestGatewayRoutingSuite(t *testing.T) {

View File

@@ -2,10 +2,11 @@ package repository
import (
"context"
"fmt"
"net/http"
"net/url"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
@@ -38,16 +39,17 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
resp, err := client.R().
SetContext(ctx).
SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
@@ -66,16 +68,17 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
resp, err := client.R().
SetContext(ctx).
SetHeader("User-Agent", "codex-cli/0.91.0").
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
@@ -84,6 +87,6 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
func createOpenAIReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
Timeout: 120 * time.Second,
})
}

View File

@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
require.ErrorContains(s.T(), err, "status 401")
}
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
client := NewOpenAIOAuthClient()
svc, ok := client.(*openaiOAuthService)
require.True(t, ok)
require.Equal(t, openai.TokenURL, svc.tokenURL)
}
func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite))
}

View File

@@ -14,6 +14,7 @@ type reqClientOptions struct {
ProxyURL string // 代理 URL支持 http/https/socks5
Timeout time.Duration // 请求超时时间
Impersonate bool // 是否模拟 Chrome 浏览器指纹
ForceHTTP2 bool // 是否强制使用 HTTP/2
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
client := req.C().SetTimeout(opts.Timeout)
if opts.ForceHTTP2 {
client = client.EnableForceHTTP2()
}
if opts.Impersonate {
client = client.ImpersonateChrome()
}
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
func buildReqClientKey(opts reqClientOptions) string {
return fmt.Sprintf("%s|%s|%t",
return fmt.Sprintf("%s|%s|%t|%t",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.Impersonate,
opts.ForceHTTP2,
)
}

View File

@@ -0,0 +1,90 @@
package repository
import (
"reflect"
"sync"
"testing"
"time"
"unsafe"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
)
func forceHTTPVersion(t *testing.T, client *req.Client) string {
t.Helper()
transport := client.GetTransport()
field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
require.True(t, field.IsValid(), "forceHttpVersion field not found")
require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
}
func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
sharedReqClients = sync.Map{}
base := reqClientOptions{
ProxyURL: "http://proxy.local:8080",
Timeout: time.Second,
}
clientDefault := getSharedReqClient(base)
force := base
force.ForceHTTP2 = true
clientForce := getSharedReqClient(force)
require.NotSame(t, clientDefault, clientForce)
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
}
func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "http://proxy.local:8080",
Timeout: 2 * time.Second,
}
first := getSharedReqClient(opts)
second := getSharedReqClient(opts)
require.Same(t, first, second)
}
func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: " http://proxy.local:8080 ",
Timeout: 3 * time.Second,
}
key := buildReqClientKey(opts)
sharedReqClients.Store(key, "invalid")
client := getSharedReqClient(opts)
require.NotNil(t, client)
loaded, ok := sharedReqClients.Load(key)
require.True(t, ok)
require.IsType(t, "invalid", loaded)
}
func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: " http://proxy.local:8080 ",
Timeout: 4 * time.Second,
Impersonate: true,
}
client := getSharedReqClient(opts)
require.NotNil(t, client)
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("http://proxy.local:8080")
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createGeminiReqClient("http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}

View File

@@ -58,7 +58,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
return nil, false, err
}
if len(ids) == 0 {
return []*service.Account{}, true, nil
// 空快照视为缓存未命中,触发数据库回退查询
// 这解决了新分组创建后立即绑定账号时的竞态条件问题
return nil, false, nil
}
keys := make([]string, 0, len(ids))

View File

@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
cache := NewSchedulerCache(rdb)

View File

@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) {
if len(accountIDs) == 0 {
return make(map[int64]int), nil
}
@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco
// 使用 pipeline 批量执行
pipe := c.rdb.Pipeline()
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
for _, accountID := range accountIDs {
key := sessionLimitKey(accountID)
// 使用各账号自己的 idleTimeout如果没有则用默认值
idleTimeout := c.defaultIdleTimeout
if idleTimeouts != nil {
if t, ok := idleTimeouts[accountID]; ok && t > 0 {
idleTimeout = t
}
}
idleTimeoutSeconds := int(idleTimeout.Seconds())
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
}

View File

@@ -0,0 +1,149 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const (
totpSetupKeyPrefix = "totp:setup:"
totpLoginKeyPrefix = "totp:login:"
totpAttemptsKeyPrefix = "totp:attempts:"
totpAttemptsTTL = 15 * time.Minute
)
// TotpCache implements service.TotpCache using Redis
type TotpCache struct {
rdb *redis.Client
}
// NewTotpCache creates a new TOTP cache
func NewTotpCache(rdb *redis.Client) service.TotpCache {
return &TotpCache{rdb: rdb}
}
// GetSetupSession retrieves a TOTP setup session
func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
data, err := c.rdb.Get(ctx, key).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("get setup session: %w", err)
}
var session service.TotpSetupSession
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("unmarshal setup session: %w", err)
}
return &session, nil
}
// SetSetupSession stores a TOTP setup session
func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
data, err := json.Marshal(session)
if err != nil {
return fmt.Errorf("marshal setup session: %w", err)
}
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("set setup session: %w", err)
}
return nil
}
// DeleteSetupSession deletes a TOTP setup session
func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
// GetLoginSession retrieves a TOTP login session
func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
key := totpLoginKeyPrefix + tempToken
data, err := c.rdb.Get(ctx, key).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("get login session: %w", err)
}
var session service.TotpLoginSession
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("unmarshal login session: %w", err)
}
return &session, nil
}
// SetLoginSession stores a TOTP login session
func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
key := totpLoginKeyPrefix + tempToken
data, err := json.Marshal(session)
if err != nil {
return fmt.Errorf("marshal login session: %w", err)
}
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
return fmt.Errorf("set login session: %w", err)
}
return nil
}
// DeleteLoginSession deletes a TOTP login session
func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
key := totpLoginKeyPrefix + tempToken
return c.rdb.Del(ctx, key).Err()
}
// IncrementVerifyAttempts increments the verify attempt counter
func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
// Use pipeline for atomic increment and set TTL
pipe := c.rdb.Pipeline()
incrCmd := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, totpAttemptsTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("increment verify attempts: %w", err)
}
count, err := incrCmd.Result()
if err != nil {
return 0, fmt.Errorf("get increment result: %w", err)
}
return int(count), nil
}
// GetVerifyAttempts gets the current verify attempt count
func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
count, err := c.rdb.Get(ctx, key).Int()
if err != nil {
if err == redis.Nil {
return 0, nil
}
return 0, fmt.Errorf("get verify attempts: %w", err)
}
return count, nil
}
// ClearVerifyAttempts clears the verify attempt counter
func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"sort"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
@@ -466,3 +467,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
client := clientFromContext(ctx, r.client)
update := client.User.UpdateOneID(userID)
if encryptedSecret == nil {
update = update.ClearTotpSecretEncrypted()
} else {
update = update.SetTotpSecretEncrypted(*encryptedSecret)
}
_, err := update.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
// EnableTotp 启用用户的 TOTP 双因素认证
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.User.UpdateOneID(userID).
SetTotpEnabled(true).
SetTotpEnabledAt(time.Now()).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}
// DisableTotp 禁用用户的 TOTP 双因素认证
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.User.UpdateOneID(userID).
SetTotpEnabled(false).
ClearTotpEnabledAt().
ClearTotpSecretEncrypted().
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return nil
}

View File

@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID))
}
if status != "" {
// Status filtering with real-time expiration check
now := time.Now()
switch status {
case service.SubscriptionStatusActive:
// Active: status is active AND not yet expired
q = q.Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtGT(now),
)
case service.SubscriptionStatusExpired:
// Expired: status is expired OR (status is active but already expired)
q = q.Where(
usersubscription.Or(
usersubscription.StatusEQ(service.SubscriptionStatusExpired),
usersubscription.And(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(now),
),
),
)
case "":
// No filter
default:
// Other status (e.g., revoked)
q = q.Where(usersubscription.StatusEQ(status))
}
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return nil, nil, err
}
// Apply sorting
q = q.WithUser().WithGroup().WithAssignedByUser()
// Determine sort field
var field string
switch sortBy {
case "expires_at":
field = usersubscription.FieldExpiresAt
case "status":
field = usersubscription.FieldStatus
default:
field = usersubscription.FieldCreatedAt
}
// Determine sort order (default: desc)
if sortOrder == "asc" && sortBy != "" {
q = q.Order(dbent.Asc(field))
} else {
q = q.Order(dbent.Desc(field))
}
subs, err := q.
WithUser().
WithGroup().
WithAssignedByUser().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
Offset(params.Offset()).
Limit(params.Limit()).
All(ctx)

View File

@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)

View File

@@ -82,6 +82,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache,
NewSchedulerOutboxRepository,
NewProxyLatencyCache,
NewTotpCache,
// Encryptors
NewAESEncryptor,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,

View File

@@ -51,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
"id": 1,
"email": "alice@example.com",
"username": "alice",
"notes": "hello",
"role": "user",
"balance": 12.5,
"concurrency": 5,
@@ -131,6 +130,153 @@ func TestAPIContracts(t *testing.T) {
}
}`,
},
{
name: "GET /api/v1/groups/available",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
// 普通用户可见的分组列表不应包含内部字段(如 model_routing/account_count
deps.groupRepo.SetActive([]service.Group{
{
ID: 10,
Name: "Group One",
Description: "desc",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.5,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-3-*": []int64{101, 102},
},
AccountCount: 2,
CreatedAt: deps.now,
UpdatedAt: deps.now,
},
})
deps.userSubRepo.SetActiveByUserID(1, nil)
},
method: http.MethodGet,
path: "/api/v1/groups/available",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": [
{
"id": 10,
"name": "Group One",
"description": "desc",
"platform": "anthropic",
"rate_multiplier": 1.5,
"is_exclusive": false,
"status": "active",
"subscription_type": "standard",
"daily_limit_usd": null,
"weekly_limit_usd": null,
"monthly_limit_usd": null,
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"claude_code_only": false,
"fallback_group_id": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
]
}`,
},
{
name: "GET /api/v1/subscriptions",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
deps.userSubRepo.SetByUserID(1, []service.UserSubscription{
{
ID: 501,
UserID: 1,
GroupID: 10,
StartsAt: deps.now,
ExpiresAt: time.Date(2099, 1, 2, 3, 4, 5, 0, time.UTC), // 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
Status: service.SubscriptionStatusActive,
DailyUsageUSD: 1.23,
WeeklyUsageUSD: 2.34,
MonthlyUsageUSD: 3.45,
AssignedBy: ptr(int64(999)),
AssignedAt: deps.now,
Notes: "admin-note",
CreatedAt: deps.now,
UpdatedAt: deps.now,
},
})
},
method: http.MethodGet,
path: "/api/v1/subscriptions",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": [
{
"id": 501,
"user_id": 1,
"group_id": 10,
"starts_at": "2025-01-02T03:04:05Z",
"expires_at": "2099-01-02T03:04:05Z",
"status": "active",
"daily_window_start": null,
"weekly_window_start": null,
"monthly_window_start": null,
"daily_usage_usd": 1.23,
"weekly_usage_usd": 2.34,
"monthly_usage_usd": 3.45,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
]
}`,
},
{
name: "GET /api/v1/redeem/history",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
// 普通用户兑换历史不应包含 notes 等内部字段。
deps.redeemRepo.SetByUser(1, []service.RedeemCode{
{
ID: 900,
Code: "CODE-123",
Type: service.RedeemTypeBalance,
Value: 1.25,
Status: service.StatusUsed,
UsedBy: ptr(int64(1)),
UsedAt: ptr(deps.now),
Notes: "internal-note",
CreatedAt: deps.now,
},
})
},
method: http.MethodGet,
path: "/api/v1/redeem/history",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": [
{
"id": 900,
"code": "CODE-123",
"type": "balance",
"value": 1.25,
"status": "used",
"used_by": 1,
"used_at": "2025-01-02T03:04:05Z",
"created_at": "2025-01-02T03:04:05Z",
"group_id": null,
"validity_days": 0
}
]
}`,
},
{
name: "GET /api/v1/usage/stats",
setup: func(t *testing.T, deps *contractDeps) {
@@ -190,24 +336,25 @@ func TestAPIContracts(t *testing.T) {
t.Helper()
deps.usageRepo.SetUserLogs(1, []service.UsageLog{
{
ID: 1,
UserID: 1,
APIKeyID: 100,
AccountID: 200,
RequestID: "req_123",
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 1,
CacheReadTokens: 2,
TotalCost: 0.5,
ActualCost: 0.5,
RateMultiplier: 1,
BillingType: service.BillingTypeBalance,
Stream: true,
DurationMs: ptr(100),
FirstTokenMs: ptr(50),
CreatedAt: deps.now,
ID: 1,
UserID: 1,
APIKeyID: 100,
AccountID: 200,
AccountRateMultiplier: ptr(0.5),
RequestID: "req_123",
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 1,
CacheReadTokens: 2,
TotalCost: 0.5,
ActualCost: 0.5,
RateMultiplier: 1,
BillingType: service.BillingTypeBalance,
Stream: true,
DurationMs: ptr(100),
FirstTokenMs: ptr(50),
CreatedAt: deps.now,
},
})
},
@@ -238,10 +385,9 @@ func TestAPIContracts(t *testing.T) {
"output_cost": 0,
"cache_creation_cost": 0,
"cache_read_cost": 0,
"total_cost": 0.5,
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
@@ -266,6 +412,7 @@ func TestAPIContracts(t *testing.T) {
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
@@ -304,6 +451,10 @@ func TestAPIContracts(t *testing.T) {
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"promo_code_enabled": true,
"password_reset_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
@@ -337,7 +488,10 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"home_content": ""
"home_content": "",
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": ""
}
}`,
},
@@ -385,8 +539,11 @@ type contractDeps struct {
now time.Time
router http.Handler
apiKeyRepo *stubApiKeyRepo
groupRepo *stubGroupRepo
userSubRepo *stubUserSubscriptionRepo
usageRepo *stubUsageLogRepo
settingRepo *stubSettingRepo
redeemRepo *stubRedeemCodeRepo
}
func newContractDeps(t *testing.T) *contractDeps {
@@ -414,11 +571,11 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyRepo := newStubApiKeyRepo(now)
apiKeyCache := stubApiKeyCache{}
groupRepo := stubGroupRepo{}
userSubRepo := stubUserSubscriptionRepo{}
groupRepo := &stubGroupRepo{}
userSubRepo := &stubUserSubscriptionRepo{}
accountRepo := stubAccountRepo{}
proxyRepo := stubProxyRepo{}
redeemRepo := stubRedeemCodeRepo{}
redeemRepo := &stubRedeemCodeRepo{}
cfg := &config.Config{
Default: config.DefaultConfig{
@@ -433,15 +590,21 @@ func newContractDeps(t *testing.T) *contractDeps {
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
redeemHandler := handler.NewRedeemHandler(redeemService)
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
@@ -472,12 +635,21 @@ func newContractDeps(t *testing.T) *contractDeps {
v1Keys.Use(jwtAuth)
v1Keys.GET("/keys", apiKeyHandler.List)
v1Keys.POST("/keys", apiKeyHandler.Create)
v1Keys.GET("/groups/available", apiKeyHandler.GetAvailableGroups)
v1Usage := v1.Group("")
v1Usage.Use(jwtAuth)
v1Usage.GET("/usage", usageHandler.List)
v1Usage.GET("/usage/stats", usageHandler.Stats)
v1Subs := v1.Group("")
v1Subs.Use(jwtAuth)
v1Subs.GET("/subscriptions", subscriptionHandler.List)
v1Redeem := v1.Group("")
v1Redeem.Use(jwtAuth)
v1Redeem.GET("/redeem/history", redeemHandler.GetHistory)
v1Admin := v1.Group("/admin")
v1Admin.Use(adminAuth)
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
@@ -487,8 +659,11 @@ func newContractDeps(t *testing.T) *contractDeps {
now: now,
router: r,
apiKeyRepo: apiKeyRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
usageRepo: usageRepo,
settingRepo: settingRepo,
redeemRepo: redeemRepo,
}
}
@@ -584,6 +759,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented")
}
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
return errors.New("not implemented")
}
type stubApiKeyCache struct{}
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
@@ -626,7 +813,13 @@ func (stubApiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handl
return nil
}
type stubGroupRepo struct{}
type stubGroupRepo struct {
active []service.Group
}
func (r *stubGroupRepo) SetActive(groups []service.Group) {
r.active = append([]service.Group(nil), groups...)
}
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
return errors.New("not implemented")
@@ -660,12 +853,19 @@ func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.Pagi
return nil, nil, errors.New("not implemented")
}
func (stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
return nil, errors.New("not implemented")
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
return append([]service.Group(nil), r.active...), nil
}
func (stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return nil, errors.New("not implemented")
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
out := make([]service.Group, 0, len(r.active))
for i := range r.active {
g := r.active[i]
if g.Platform == platform {
out = append(out, g)
}
}
return out, nil
}
func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
@@ -883,7 +1083,16 @@ func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID
return nil, errors.New("not implemented")
}
type stubRedeemCodeRepo struct{}
type stubRedeemCodeRepo struct {
byUser map[int64][]service.RedeemCode
}
func (r *stubRedeemCodeRepo) SetByUser(userID int64, codes []service.RedeemCode) {
if r.byUser == nil {
r.byUser = make(map[int64][]service.RedeemCode)
}
r.byUser[userID] = append([]service.RedeemCode(nil), codes...)
}
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
return errors.New("not implemented")
@@ -921,11 +1130,35 @@ func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination
return nil, nil, errors.New("not implemented")
}
func (stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
return nil, errors.New("not implemented")
func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
if r.byUser == nil {
return nil, nil
}
codes := r.byUser[userID]
if limit > 0 && len(codes) > limit {
codes = codes[:limit]
}
return append([]service.RedeemCode(nil), codes...), nil
}
type stubUserSubscriptionRepo struct{}
type stubUserSubscriptionRepo struct {
byUser map[int64][]service.UserSubscription
activeByUser map[int64][]service.UserSubscription
}
func (r *stubUserSubscriptionRepo) SetByUserID(userID int64, subs []service.UserSubscription) {
if r.byUser == nil {
r.byUser = make(map[int64][]service.UserSubscription)
}
r.byUser[userID] = append([]service.UserSubscription(nil), subs...)
}
func (r *stubUserSubscriptionRepo) SetActiveByUserID(userID int64, subs []service.UserSubscription) {
if r.activeByUser == nil {
r.activeByUser = make(map[int64][]service.UserSubscription)
}
r.activeByUser[userID] = append([]service.UserSubscription(nil), subs...)
}
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
@@ -945,16 +1178,22 @@ func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSub
func (stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
if r.byUser == nil {
return nil, nil
}
return append([]service.UserSubscription(nil), r.byUser[userID]...), nil
}
func (stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
if r.activeByUser == nil {
return nil, nil
}
return append([]service.UserSubscription(nil), r.activeByUser[userID]...), nil
}
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {

View File

@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return nil, nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}

View File

@@ -26,11 +26,20 @@ func RegisterAuthRoutes(
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidatePromoCode)
// 忘记密码接口添加速率限制:每分钟最多 5 次Redis 故障时 fail-close
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ForgotPassword)
// 重置密码接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
}

View File

@@ -22,6 +22,17 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
// TOTP 双因素认证
totp := user.Group("/totp")
{
totp.GET("/status", h.Totp.GetStatus)
totp.GET("/verification-method", h.Totp.GetVerificationMethod)
totp.POST("/send-code", h.Totp.SendVerifyCode)
totp.POST("/setup", h.Totp.InitiateSetup)
totp.POST("/enable", h.Totp.Enable)
totp.POST("/disable", h.Totp.Disable)
}
}
// API Key管理

View File

@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil
}
// GetCredentialAsInt64 解析凭证中的 int64 字段
// 用于读取 _token_version 等内部字段
func (a *Account) GetCredentialAsInt64(key string) int64 {
if a == nil || a.Credentials == nil {
return 0
}
val, ok := a.Credentials[key]
if !ok || val == nil {
return 0
}
switch v := val.(type) {
case int64:
return v
case float64:
return int64(v)
case int:
return int64(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return i
}
case string:
if i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
return i
}
}
return 0
}
func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil {
return false

View File

@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct {
} `json:"seven_day_sonnet"`
}
// ClaudeUsageFetchOptions 包含获取 Claude 用量数据所需的所有选项
type ClaudeUsageFetchOptions struct {
AccessToken string // OAuth access token
ProxyURL string // 代理 URL可选
AccountID int64 // 账号 ID用于 TLS 指纹选择)
EnableTLSFingerprint bool // 是否启用 TLS 指纹伪装
Fingerprint *Fingerprint // 缓存的指纹信息User-Agent 等)
}
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
type ClaudeUsageFetcher interface {
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
// FetchUsageWithOptions 使用完整选项获取用量数据,支持 TLS 指纹和自定义 User-Agent
FetchUsageWithOptions(ctx context.Context, opts *ClaudeUsageFetchOptions) (*ClaudeUsageResponse, error)
}
// AccountUsageService 账号使用量查询服务
@@ -170,6 +181,7 @@ type AccountUsageService struct {
geminiQuotaService *GeminiQuotaService
antigravityQuotaFetcher *AntigravityQuotaFetcher
cache *UsageCache
identityCache IdentityCache
}
// NewAccountUsageService 创建AccountUsageService实例
@@ -180,6 +192,7 @@ func NewAccountUsageService(
geminiQuotaService *GeminiQuotaService,
antigravityQuotaFetcher *AntigravityQuotaFetcher,
cache *UsageCache,
identityCache IdentityCache,
) *AccountUsageService {
return &AccountUsageService{
accountRepo: accountRepo,
@@ -188,6 +201,7 @@ func NewAccountUsageService(
geminiQuotaService: geminiQuotaService,
antigravityQuotaFetcher: antigravityQuotaFetcher,
cache: cache,
identityCache: identityCache,
}
}
@@ -424,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
}
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo
// 如果账号开启了 TLS 指纹,则使用 TLS 指纹伪装
// 如果有缓存的 Fingerprint则使用缓存的 User-Agent 等信息
func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
accessToken := account.GetCredential("access_token")
if accessToken == "" {
@@ -435,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A
proxyURL = account.Proxy.URL()
}
return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
// 构建完整的选项
opts := &ClaudeUsageFetchOptions{
AccessToken: accessToken,
ProxyURL: proxyURL,
AccountID: account.ID,
EnableTLSFingerprint: account.IsTLSFingerprintEnabled(),
}
// 尝试获取缓存的 Fingerprint包含 User-Agent 等信息)
if s.identityCache != nil {
if fp, err := s.identityCache.GetFingerprint(ctx, account.ID); err == nil && fp != nil {
opts.Fingerprint = fp
}
}
return s.usageFetcher.FetchUsageWithOptions(ctx, opts)
}
// parseTime 尝试多种格式解析时间

View File

@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
}
type groupRepoStub struct {
affectedUserIDs []int64
deleteErr error

View File

@@ -1305,6 +1305,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, err
}
// 清理 Schema
if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
injectedBody = cleanedBody
log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
} else {
log.Printf("[Antigravity] Failed to clean schema: %v", err)
}
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
if err != nil {
@@ -1705,6 +1713,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
// Check for MALFORMED_FUNCTION_CALL
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
if content, ok := cand["content"]; ok {
if b, err := json.Marshal(content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
}
}
if firstTokenMs == nil {
@@ -1854,6 +1875,20 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
usage = u
}
// Check for MALFORMED_FUNCTION_CALL
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
if content, ok := cand["content"]; ok {
if b, err := json.Marshal(content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
}
}
}
}
}
// 保留最后一个有 parts 的响应
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
@@ -1950,6 +1985,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi
return result, existingParts, setParts
}
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
// 这个函数会合并所有类型的 partstext、thinking、functionCall、inlineData 等
// 保持原始顺序,只合并连续的普通 text parts
func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
if len(collectedParts) == 0 {
return response
}
result, _, setParts := getOrCreateGeminiParts(response)
// 合并策略:
// 1. 保持原始顺序
// 2. 连续的普通 text parts 合并为一个
// 3. thinking、functionCall、inlineData 等保持原样
var mergedParts []any
var textBuffer strings.Builder
flushTextBuffer := func() {
if textBuffer.Len() > 0 {
mergedParts = append(mergedParts, map[string]any{
"text": textBuffer.String(),
})
textBuffer.Reset()
}
}
for _, part := range collectedParts {
// 检查是否是普通 text part
if text, ok := part["text"].(string); ok {
// 检查是否有 thought 标记
if thought, _ := part["thought"].(bool); thought {
// thinking part先刷新 text buffer然后保留原样
flushTextBuffer()
mergedParts = append(mergedParts, part)
} else {
// 普通 text累积到 buffer
_, _ = textBuffer.WriteString(text)
}
} else {
// 非 text partfunctionCall、inlineData 等),先刷新 text buffer然后保留原样
flushTextBuffer()
mergedParts = append(mergedParts, part)
}
}
// 刷新剩余的 text
flushTextBuffer()
setParts(mergedParts)
return result
}
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
if len(imageParts) == 0 {
@@ -2133,6 +2220,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var firstTokenMs *int
var last map[string]any
var lastWithParts map[string]any
var collectedParts []map[string]any // 收集所有 parts包括 text、thinking、functionCall、inlineData 等)
type scanEvent struct {
line string
@@ -2227,9 +2315,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
last = parsed
// 保留最后一个有 parts 的响应
// 保留最后一个有 parts 的响应,并收集所有 parts
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
// 收集所有 partstext、thinking、functionCall、inlineData 等)
collectedParts = append(collectedParts, parts...)
}
case <-intervalCh:
@@ -2252,6 +2343,11 @@ returnResponse:
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
}
// 将收集的所有 parts 合并到最终响应中
if len(collectedParts) > 0 {
finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
}
// 序列化为 JSONGemini 格式)
geminiBody, err := json.Marshal(finalResponse)
if err != nil {
@@ -2459,3 +2555,55 @@ func isImageGenerationModel(model string) bool {
modelLower == "gemini-2.5-flash-image-preview" ||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
}
// cleanGeminiRequest 清理 Gemini 请求体中的 Schema
func cleanGeminiRequest(body []byte) ([]byte, error) {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return nil, err
}
modified := false
// 1. 清理 Tools
if tools, ok := payload["tools"].([]any); ok && len(tools) > 0 {
for _, t := range tools {
toolMap, ok := t.(map[string]any)
if !ok {
continue
}
// function_declarations (snake_case) or functionDeclarations (camelCase)
var funcs []any
if f, ok := toolMap["functionDeclarations"].([]any); ok {
funcs = f
} else if f, ok := toolMap["function_declarations"].([]any); ok {
funcs = f
}
if len(funcs) == 0 {
continue
}
for _, f := range funcs {
funcMap, ok := f.(map[string]any)
if !ok {
continue
}
if params, ok := funcMap["parameters"].(map[string]any); ok {
antigravity.DeepCleanUndefined(params)
cleaned := antigravity.CleanJSONSchema(params)
funcMap["parameters"] = cleaned
modified = true
}
}
}
}
if !modified {
return body, nil
}
return json.Marshal(payload)
}

View File

@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.Email = userInfo.Email
}
// 获取 project_id部分账户类型可能没有
loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" {
result.ProjectID = loadResp.CloudAICompanionProject
// 获取 project_id部分账户类型可能没有,失败时重试
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
result.ProjectIDMissing = true
} else {
result.ProjectID = projectID
}
return result, nil
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
tokenInfo.Email = existingEmail
}
// 每次刷新都调用 LoadCodeAssist 获取 project_id
client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
// LoadCodeAssist 失败或返回空,保留原有 project_id标记缺失
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
// 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if loadErr != nil {
// LoadCodeAssist 失败,保留原有 project_id
tokenInfo.ProjectID = existingProjectID
tokenInfo.ProjectIDMissing = true
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
// 如果之前有 project_id本次只是临时故障不应标记为错误
if existingProjectID == "" {
tokenInfo.ProjectIDMissing = true
}
} else {
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
}
// loadProjectIDWithRetry 带重试机制获取 project_id
// 返回 project_id 和错误,失败时会重试指定次数
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
var lastErr error
for attempt := 0; attempt <= maxRetries; attempt++ {
if attempt > 0 {
// 指数退避1s, 2s, 4s
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
if backoff > 8*time.Second {
backoff = 8 * time.Second
}
time.Sleep(backoff)
}
client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
return loadResp.CloudAICompanionProject, nil
}
// 记录错误
if err != nil {
lastErr = err
} else if loadResp == nil {
lastErr = fmt.Errorf("LoadCodeAssist 返回空响应")
} else {
lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id")
}
}
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
}
// BuildAccountCredentials 构建账户凭证
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
creds := map[string]any{

View File

@@ -94,14 +94,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
var handleErrorCalled bool
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
prefix: "[test]",
ctx: context.Background(),
account: account,
proxyURL: "",
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
prefix: "[test]",
ctx: context.Background(),
account: account,
proxyURL: "",
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
handleErrorCalled = true

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
"time"
@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > antigravityTokenCacheSkew:
ttl = until - antigravityTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > antigravityTokenCacheSkew:
ttl = until - antigravityTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil

View File

@@ -3,6 +3,8 @@ package service
import (
"context"
"fmt"
"log"
"strings"
"time"
)
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
}
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
// 合并旧的 credentials保留新 credentials 中不存在的字段
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败project_id 也不会丢失
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
newCredentials["project_id"] = oldProjectID
}
}
// 如果 project_id 获取失败,只记录警告,不返回错误
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
// Token 刷新本身是成功的access_token 和 refresh_token 已更新)
if tokenInfo.ProjectIDMissing {
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id可能无法使用Antigravity")
if tokenInfo.ProjectID != "" {
// 有旧的 project_id本次获取失败保留旧值
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID)
} else {
// 从未获取过 project_id本次也失败但不返回错误以允许下次重试
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID)
}
}
return newCredentials, nil

View File

@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable
}
// 应用优惠码(如果提供)
if promoCode != "" && s.promoService != nil {
// 应用优惠码(如果提供且功能已启用
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
// 优惠码应用失败不影响注册,只记录日志
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
@@ -580,3 +580,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 生成新token
return s.GenerateToken(user)
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证且 SMTP 配置正确
func (s *AuthService) IsPasswordResetEnabled(ctx context.Context) bool {
if s.settingService == nil {
return false
}
// Must have email verification enabled and SMTP configured
if !s.settingService.IsEmailVerifyEnabled(ctx) {
return false
}
return s.settingService.IsPasswordResetEnabled(ctx)
}
// preparePasswordReset validates the password reset request and returns necessary data
// Returns (siteName, resetURL, shouldProceed)
// shouldProceed is false when we should silently return success (to prevent enumeration)
func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendBaseURL string) (string, string, bool) {
// Check if user exists (but don't reveal this to the caller)
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// Security: Log but don't reveal that user doesn't exist
log.Printf("[Auth] Password reset requested for non-existent email: %s", email)
return "", "", false
}
log.Printf("[Auth] Database error checking email for password reset: %v", err)
return "", "", false
}
// Check if user is active
if !user.IsActive() {
log.Printf("[Auth] Password reset requested for inactive user: %s", email)
return "", "", false
}
// Get site name
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
// Build reset URL base
resetURL := fmt.Sprintf("%s/reset-password", strings.TrimSuffix(frontendBaseURL, "/"))
return siteName, resetURL, true
}
// RequestPasswordReset 请求密码重置(同步发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error {
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailService == nil {
return ErrServiceUnavailable
}
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
if !shouldProceed {
return nil // Silent success to prevent enumeration
}
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err)
return nil // Silent success to prevent enumeration
}
log.Printf("[Auth] Password reset email sent to: %s", email)
return nil
}
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error {
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailQueueService == nil {
return ErrServiceUnavailable
}
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
if !shouldProceed {
return nil // Silent success to prevent enumeration
}
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err)
return nil // Silent success to prevent enumeration
}
log.Printf("[Auth] Password reset email enqueued for: %s", email)
return nil
}
// ResetPassword 重置密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPassword string) error {
// Check if password reset is enabled
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailService == nil {
return ErrServiceUnavailable
}
// Verify and consume the reset token (one-time use)
if err := s.emailService.ConsumePasswordResetToken(ctx, email, token); err != nil {
return err
}
// Get user
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return ErrInvalidResetToken // Token was valid but user was deleted
}
log.Printf("[Auth] Database error getting user for password reset: %v", err)
return ErrServiceUnavailable
}
// Check if user is active
if !user.IsActive() {
return ErrUserNotActive
}
// Hash new password
hashedPassword, err := s.HashPassword(newPassword)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
// Update password and increment TokenVersion
user.PasswordHash = hashedPassword
user.TokenVersion++ // Invalidate all existing tokens
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err)
return ErrServiceUnavailable
}
log.Printf("[Auth] Password reset successful for user: %s", email)
return nil
}

View File

@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return nil
}
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
return nil, nil
}
func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error {
return nil
}
func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error {
return nil
}
func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
return false
}
func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
return nil
}
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{

View File

@@ -181,26 +181,37 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > claudeTokenCacheSkew:
ttl = until - claudeTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > claudeTokenCacheSkew:
ttl = until - claudeTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}

View File

@@ -69,8 +69,10 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
@@ -86,6 +88,9 @@ const (
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// TOTP 双因素认证设置
SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能
// LinuxDo Connect OAuth 登录设置
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
@@ -93,13 +98,16 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML或 URL 作为 iframe src
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML或 URL 作为 iframe src
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL作为 iframe src
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量

View File

@@ -8,11 +8,18 @@ import (
"time"
)
// Task type constants
const (
TaskTypeVerifyCode = "verify_code"
TaskTypePasswordReset = "password_reset"
)
// EmailTask 邮件发送任务
type EmailTask struct {
Email string
SiteName string
TaskType string // "verify_code"
TaskType string // "verify_code" or "password_reset"
ResetURL string // Only used for password_reset task type
}
// EmailQueueService 异步邮件队列服务
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
defer cancel()
switch task.TaskType {
case "verify_code":
case TaskTypeVerifyCode:
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
}
case TaskTypePasswordReset:
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
}
default:
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
}
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: "verify_code",
TaskType: TaskTypeVerifyCode,
}
select {
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
}
}
// EnqueuePasswordReset 将密码重置邮件任务加入队列
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: TaskTypePasswordReset,
ResetURL: resetURL,
}
select {
case s.taskChan <- task:
log.Printf("[EmailQueue] Enqueued password reset task for %s", email)
return nil
default:
return fmt.Errorf("email queue is full")
}
}
// Stop 停止队列服务
func (s *EmailQueueService) Stop() {
close(s.stopChan)

View File

@@ -3,11 +3,14 @@ package service
import (
"context"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"encoding/hex"
"fmt"
"log"
"math/big"
"net/smtp"
"net/url"
"strconv"
"time"
@@ -19,6 +22,9 @@ var (
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
// Password reset errors
ErrInvalidResetToken = infraerrors.BadRequest("INVALID_RESET_TOKEN", "invalid or expired password reset token")
)
// EmailCache defines cache operations for email service
@@ -26,6 +32,16 @@ type EmailCache interface {
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error
// Password reset token methods
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
DeletePasswordResetToken(ctx context.Context, email string) error
// Password reset email cooldown methods
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
}
// VerificationCodeData represents verification code data
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
CreatedAt time.Time
}
// PasswordResetTokenData represents password reset token data
type PasswordResetTokenData struct {
Token string
CreatedAt time.Time
}
const (
verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute
maxVerifyCodeAttempts = 5
// Password reset token settings
passwordResetTokenTTL = 30 * time.Minute
// Password reset email cooldown (prevent email bombing)
passwordResetEmailCooldown = 30 * time.Second
)
// SMTPConfig SMTP配置
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
return ErrVerifyCodeMaxAttempts
}
// 验证码不匹配
if data.Code != code {
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err)
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
return client.Quit()
}
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
func (s *EmailService) GeneratePasswordResetToken() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// SendPasswordResetEmail sends a password reset email with a reset link
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error {
var token string
var needSaveToken bool
// Check if token already exists
existing, err := s.cache.GetPasswordResetToken(ctx, email)
if err == nil && existing != nil {
// Token exists, reuse it (allows resending email without generating new token)
token = existing.Token
needSaveToken = false
} else {
// Generate new token
token, err = s.GeneratePasswordResetToken()
if err != nil {
return fmt.Errorf("generate token: %w", err)
}
needSaveToken = true
}
// Save token to Redis (only if new token generated)
if needSaveToken {
data := &PasswordResetTokenData{
Token: token,
CreatedAt: time.Now(),
}
if err := s.cache.SetPasswordResetToken(ctx, email, data, passwordResetTokenTTL); err != nil {
return fmt.Errorf("save reset token: %w", err)
}
}
// Build full reset URL with URL-encoded token and email
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
// Build email content
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
// Send email
if err := s.SendEmail(ctx, email, subject, body); err != nil {
return fmt.Errorf("send email: %w", err)
}
return nil
}
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
// Check email cooldown to prevent email bombing
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
return nil // Silent success to prevent revealing cooldown to attackers
}
// Send email using core method
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
return err
}
// Set cooldown marker (Redis TTL handles expiration)
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
}
return nil
}
// VerifyPasswordResetToken verifies the password reset token without consuming it
func (s *EmailService) VerifyPasswordResetToken(ctx context.Context, email, token string) error {
data, err := s.cache.GetPasswordResetToken(ctx, email)
if err != nil || data == nil {
return ErrInvalidResetToken
}
// Use constant-time comparison to prevent timing attacks
if subtle.ConstantTimeCompare([]byte(data.Token), []byte(token)) != 1 {
return ErrInvalidResetToken
}
return nil
}
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, token string) error {
// Verify first
if err := s.VerifyPasswordResetToken(ctx, email, token); err != nil {
return err
}
// Delete after verification (one-time use)
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
}
return nil
}
// buildPasswordResetEmailBody builds the HTML content for password reset email
func (s *EmailService) buildPasswordResetEmailBody(resetURL, siteName string) string {
return fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
.button:hover { opacity: 0.9; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
.warning { color: #e74c3c; font-weight: 500; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">密码重置请求</p>
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
<a href="%s" class="button">重置密码</a>
<div class="info">
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
</div>
<div class="link-fallback">
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
<p>%s</p>
</div>
</div>
<div class="footer">
<p>这是一封自动发送的邮件,请勿回复。</p>
</div>
</div>
</body>
</html>
`, siteName, resetURL, resetURL)
}

File diff suppressed because it is too large Load Diff

View File

@@ -99,11 +99,24 @@ var allowedHeaders = map[string]bool{
"content-type": true,
}
// GatewayCache defines cache operations for gateway service
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话Sticky Session的存储、查询、刷新和删除功能。
//
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface {
// GetSessionAccountID 获取粘性会话绑定的账号 ID
// Get the account ID bound to a sticky session
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
// SetSessionAccountID 设置粘性会话与账号的绑定关系
// Set the binding between sticky session and account
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
// RefreshSessionTTL 刷新粘性会话的过期时间
// Refresh the expiration time of a sticky session
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
@@ -114,6 +127,28 @@ func derefGroupID(groupID *int64) int64 {
return *groupID
}
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// or within temporary unschedulable period.
// This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account) bool {
if account == nil {
return false
}
if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
return true
}
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true
}
return false
}
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
@@ -270,6 +305,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
}
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
// Returns 0 if no binding exists or on error.
func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
if sessionHash == "" || s.cache == nil {
return 0, nil
}
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err != nil {
return 0, err
}
return accountID, nil
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
@@ -658,6 +706,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
} else {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
}
}
@@ -764,41 +814,52 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, ok := accountByID[accountID]
if ok && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if ok {
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
// Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
} else {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
// Session count limit check (wait plan also requires session quota)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
} else {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
@@ -1418,14 +1479,20 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
return account, nil
}
return account, nil
}
}
}
@@ -1515,11 +1582,17 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
return account, nil
}
}
}
@@ -1619,15 +1692,21 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性原生平台直接匹配antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
return account, nil
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
return account, nil
}
}
}
@@ -1718,12 +1797,18 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性原生平台直接匹配antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
return account, nil
}
}
}
@@ -3287,17 +3372,19 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
} `json:"usage"`
}
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
// output_tokens 总是从 message_delta 获取
usage.OutputTokens = msgDelta.Usage.OutputTokens
// 如果 message_start 中没有值,则从 message_delta 获取兼容GLM等API
if usage.InputTokens == 0 {
// message_delta 仅覆盖存在且非0的字段
// 避免覆盖 message_start 中已有的值(如 input_tokens
// Claude API 的 message_delta 通常只包含 output_tokens
if msgDelta.Usage.InputTokens > 0 {
usage.InputTokens = msgDelta.Usage.InputTokens
}
if usage.CacheCreationInputTokens == 0 {
if msgDelta.Usage.OutputTokens > 0 {
usage.OutputTokens = msgDelta.Usage.OutputTokens
}
if msgDelta.Usage.CacheCreationInputTokens > 0 {
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
}
if usage.CacheReadInputTokens == 0 {
if msgDelta.Usage.CacheReadInputTokens > 0 {
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
}
}

View File

@@ -82,70 +82,23 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
}
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
var group *Group
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
group = ctxGroup
} else {
var err error
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
}
platform = group.Platform
} else {
// 无分组时只使用原生 gemini 平台
platform = PlatformGemini
// 1. 确定目标平台和调度模式
// Determine target platform and scheduling mode
platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID)
if err != nil {
return nil, err
}
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
cacheKey := "gemini:" + sessionHash
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号是否有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
valid := false
if account.Platform == platform {
valid = true
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
valid = true
}
if valid {
usable := true
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
if !ok {
usable = false
}
}
if usable {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
return account, nil
}
}
}
}
}
// 2. 尝试粘性会话命中
// Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs, platform, useMixedScheduling); account != nil {
return account, nil
}
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
// 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
// Query schedulable accounts (force platform mode: try group first, fallback to all)
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -158,56 +111,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
}
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 混合调度模式下原生平台直接通过antigravity 需要启用 mixed_scheduling
// 非混合调度模式antigravity 分组):不需要过滤
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
}
if !ok {
continue
}
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
// 4. 按优先级 + LRU 选择最佳账号
// Select best account by priority + LRU
selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling)
if selected == nil {
if requestedModel != "" {
@@ -216,6 +122,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
return nil, errors.New("no available Gemini accounts")
}
// 5. 设置粘性会话绑定
// Set sticky session binding
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
}
@@ -223,6 +131,229 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
return selected, nil
}
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
// 返回:平台名称、是否使用混合调度、是否强制平台、错误。
//
// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode.
// Returns: platform name, whether to use mixed scheduling, whether force platform, error.
func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) {
// 优先检查 context 中的强制平台(/antigravity 路由)
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
return forcePlatform, false, true, nil
}
if groupID != nil {
// 根据分组 platform 决定查询哪种账号
var group *Group
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
group = ctxGroup
} else {
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
if err != nil {
return "", false, false, fmt.Errorf("get group failed: %w", err)
}
}
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
return group.Platform, group.Platform == PlatformGemini, false, nil
}
// 无分组时只使用原生 gemini 平台
return PlatformGemini, true, false, nil
}
// tryStickySessionHit 尝试从粘性会话获取账号。
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account unavailable.
func (s *GeminiMessagesCompatService) tryStickySessionHit(
ctx context.Context,
groupID *int64,
sessionHash, cacheKey, requestedModel string,
excludedIDs map[int64]struct{},
platform string,
useMixedScheduling bool,
) *Account {
if sessionHash == "" {
return nil
}
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err != nil || accountID <= 0 {
return nil
}
if _, excluded := excludedIDs[accountID]; excluded {
return nil
}
account, err := s.getSchedulableAccount(ctx, accountID)
if err != nil {
return nil
}
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if shouldClearStickySession(account) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
return nil
}
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if !s.isAccountUsableForRequest(ctx, account, requestedModel, platform, useMixedScheduling) {
return nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
return account
}
// isAccountUsableForRequest 检查账号是否可用于当前请求。
// 验证:模型调度、模型支持、平台匹配、速率限制预检。
//
// isAccountUsableForRequest checks if account is usable for current request.
// Validates: model scheduling, model support, platform matching, rate limit precheck.
func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
ctx context.Context,
account *Account,
requestedModel, platform string,
useMixedScheduling bool,
) bool {
// 检查模型调度能力
// Check model scheduling capability
if !account.IsSchedulableForModel(requestedModel) {
return false
}
// 检查模型支持
// Check model support
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
return false
}
// 检查平台匹配
// Check platform matching
if !s.isAccountValidForPlatform(account, platform, useMixedScheduling) {
return false
}
// 速率限制预检
// Rate limit precheck
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
return false
}
return true
}
// isAccountValidForPlatform 检查账号是否匹配目标平台。
// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。
//
// isAccountValidForPlatform checks if account matches target platform.
// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling.
func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account, platform string, useMixedScheduling bool) bool {
if account.Platform == platform {
return true
}
if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
return true
}
return false
}
// passesRateLimitPreCheck 执行速率限制预检。
// 返回 true 表示通过预检或无需预检。
//
// passesRateLimitPreCheck performs rate limit precheck.
// Returns true if passed or precheck not required.
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
if s.rateLimitService == nil || requestedModel == "" {
return true
}
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
return ok
}
// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。
// 返回 nil 表示无可用账号。
//
// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred).
// Returns nil if no available account.
func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
ctx context.Context,
accounts []Account,
requestedModel string,
excludedIDs map[int64]struct{},
platform string,
useMixedScheduling bool,
) *Account {
var selected *Account
for i := range accounts {
acc := &accounts[i]
// 跳过被排除的账号
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 检查账号是否可用于当前请求
if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
continue
}
// 选择最佳账号
if selected == nil {
selected = acc
continue
}
if s.isBetterGeminiAccount(acc, selected) {
selected = acc
}
}
return selected
}
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
// 规则优先级更高数值更小优先同优先级时未使用过的优先OAuth > 非 OAuth其次是最久未使用的。
//
// isBetterGeminiAccount checks if candidate is better than current.
// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used.
func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *Account) bool {
// 优先级更高(数值更小)
if candidate.Priority < current.Priority {
return true
}
if candidate.Priority > current.Priority {
return false
}
// 同优先级,比较最后使用时间
switch {
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
// candidate 从未使用,优先
return true
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
// current 从未使用,保持
return false
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
// 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程)
return candidate.Type == AccountTypeOAuth && current.Type != AccountTypeOAuth
default:
// 都使用过,选择最久未使用的
return candidate.LastUsedAt.Before(*current.LastUsedAt)
}
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
@@ -800,6 +931,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
}
// 图片生成计费
imageCount := 0
imageSize := s.extractImageSize(body)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
@@ -807,6 +945,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Stream: req.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
@@ -1240,6 +1380,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
usage = &ClaudeUsage{}
}
// 图片生成计费
imageCount := 0
imageSize := s.extractImageSize(body)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
@@ -1247,6 +1394,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
@@ -1841,6 +1990,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
var last map[string]any
var lastWithParts map[string]any
var collectedTextParts []string // Collect all text parts for aggregation
usage := &ClaudeUsage{}
for {
@@ -1852,7 +2002,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
switch payload {
case "", "[DONE]":
if payload == "[DONE]" {
return pickGeminiCollectResult(last, lastWithParts), usage, nil
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
}
default:
var parsed map[string]any
@@ -1871,6 +2021,12 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
// Collect text from each part for aggregation
for _, part := range parts {
if text, ok := part["text"].(string); ok && text != "" {
collectedTextParts = append(collectedTextParts, text)
}
}
}
}
}
@@ -1885,7 +2041,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
}
return pickGeminiCollectResult(last, lastWithParts), usage, nil
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
}
func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any {
@@ -1898,6 +2054,83 @@ func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any)
return map[string]any{}
}
// mergeCollectedTextParts merges all collected text chunks into the final response.
// This fixes the issue where non-streaming responses only returned the last chunk
// instead of the complete aggregated text.
func mergeCollectedTextParts(response map[string]any, textParts []string) map[string]any {
if len(textParts) == 0 {
return response
}
// Join all text parts
mergedText := strings.Join(textParts, "")
// Deep copy response
result := make(map[string]any)
for k, v := range response {
result[k] = v
}
// Get or create candidates
candidates, ok := result["candidates"].([]any)
if !ok || len(candidates) == 0 {
candidates = []any{map[string]any{}}
}
// Get first candidate
candidate, ok := candidates[0].(map[string]any)
if !ok {
candidate = make(map[string]any)
candidates[0] = candidate
}
// Get or create content
content, ok := candidate["content"].(map[string]any)
if !ok {
content = map[string]any{"role": "model"}
candidate["content"] = content
}
// Get existing parts
existingParts, ok := content["parts"].([]any)
if !ok {
existingParts = []any{}
}
// Find and update first text part, or create new one
newParts := make([]any, 0, len(existingParts)+1)
textUpdated := false
for _, p := range existingParts {
pm, ok := p.(map[string]any)
if !ok {
newParts = append(newParts, p)
continue
}
if _, hasText := pm["text"]; hasText && !textUpdated {
// Replace with merged text
newPart := make(map[string]any)
for k, v := range pm {
newPart[k] = v
}
newPart["text"] = mergedText
newParts = append(newParts, newPart)
textUpdated = true
} else {
newParts = append(newParts, pm)
}
}
if !textUpdated {
newParts = append([]any{map[string]any{"text": mergedText}}, newParts...)
}
content["parts"] = newParts
result["candidates"] = candidates
return result
}
type geminiNativeStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
@@ -2816,3 +3049,26 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any {
}
return out
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string {
var req struct {
GenerationConfig *struct {
ImageConfig *struct {
ImageSize string `json:"imageSize"`
} `json:"imageConfig"`
} `json:"generationConfig"`
}
if err := json.Unmarshal(body, &req); err != nil {
return "2K"
}
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
if size == "1K" || size == "2K" || size == "4K" {
return size
}
}
return "2K"
}

View File

@@ -15,8 +15,10 @@ import (
// mockAccountRepoForGemini Gemini 测试用的 mock
type mockAccountRepoForGemini struct {
accounts []Account
accountsByID map[int64]*Account
accounts []Account
accountsByID map[int64]*Account
listByGroupFunc func(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
listByPlatformFunc func(ctx context.Context, platforms []string) ([]Account, error)
}
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
@@ -107,6 +109,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context,
return nil, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
if m.listByPlatformFunc != nil {
return m.listByPlatformFunc(ctx, platforms)
}
var result []Account
platformSet := make(map[string]bool)
for _, p := range platforms {
@@ -120,6 +125,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex
return result, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
if m.listByGroupFunc != nil {
return m.listByGroupFunc(ctx, groupID, platforms)
}
return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
@@ -215,6 +223,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
type mockGatewayCacheForGemini struct {
sessionBindings map[string]int64
deletedSessions map[string]int
}
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
@@ -236,6 +245,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group
return nil
}
func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
if m.sessionBindings == nil {
return nil
}
if m.deletedSessions == nil {
m.deletedSessions = make(map[string]int)
}
m.deletedSessions[sessionHash]++
delete(m.sessionBindings, sessionHash)
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background()
@@ -526,6 +547,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS
// 粘性会话未命中,按优先级选择
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
})
t.Run("粘性会话不可调度-清理并回退选择", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusDisabled, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-123": 1},
}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
require.Equal(t, 1, cache.deletedSessions["gemini:session-123"])
require.Equal(t, int64(2), cache.sessionBindings["gemini:session-123"])
})
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ForcePlatformFallback(t *testing.T) {
ctx := context.Background()
groupID := int64(9)
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity)
repo := &mockAccountRepoForGemini{
listByGroupFunc: func(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return nil, nil
},
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
return []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
}, nil
},
accountsByID: map[int64]*Account{
1: {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{
ID: 1,
Platform: PlatformGemini,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.0-pro": "gemini-1.0-pro"}},
},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "supporting model")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyMixedScheduling(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-999": 1},
}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-999", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_SkipDisabledMixedScheduling(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludedAccount(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
excluded := map[int64]struct{}{1: {}}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", excluded)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ListError(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
return nil, errors.New("query failed")
},
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "query accounts failed")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferOAuth(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferLeastRecentlyUsed(t *testing.T) {
ctx := context.Background()
oldTime := time.Now().Add(-2 * time.Hour)
newTime := time.Now().Add(-1 * time.Hour)
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &newTime},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &oldTime},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑

View File

@@ -0,0 +1,72 @@
package service
import (
"encoding/json"
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
// 以避免跨账号签名验证错误。
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
//
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
// to avoid cross-account signature validation errors.
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// By removing these signatures, we allow the new account to generate valid signatures.
func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
if len(body) == 0 {
return body
}
// 解析 JSON
var data any
if err := json.Unmarshal(body, &data); err != nil {
// 如果解析失败,返回原始 body可能不是 JSON 或格式不正确)
return body
}
// 递归清理 thoughtSignature
cleaned := cleanThoughtSignaturesRecursive(data)
// 重新序列化
result, err := json.Marshal(cleaned)
if err != nil {
// 如果序列化失败,返回原始 body
return body
}
return result
}
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
func cleanThoughtSignaturesRecursive(data any) any {
switch v := data.(type) {
case map[string]any:
// 创建新的 map移除 thoughtSignature
result := make(map[string]any, len(v))
for key, value := range v {
// 跳过 thoughtSignature 字段
if key == "thoughtSignature" {
continue
}
// 递归处理嵌套结构
result[key] = cleanThoughtSignaturesRecursive(value)
}
return result
case []any:
// 递归处理数组中的每个元素
result := make([]any, len(v))
for i, item := range v {
result[i] = cleanThoughtSignaturesRecursive(item)
}
return result
default:
// 基本类型string, number, bool, null直接返回
return v
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
"time"
@@ -131,21 +132,32 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
// 3) Populate cache with TTL.
// 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > geminiTokenCacheSkew:
ttl = until - geminiTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > geminiTokenCacheSkew:
ttl = until - geminiTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil

View File

@@ -48,8 +48,7 @@ type GenerateAuthURLResult struct {
// GenerateAuthURL generates an OAuth authorization URL with full scope
func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
return s.generateAuthURLWithScope(ctx, scope, proxyID)
return s.generateAuthURLWithScope(ctx, oauth.ScopeOAuth, proxyID)
}
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
@@ -123,6 +122,7 @@ type TokenInfo struct {
Scope string `json:"scope,omitempty"`
OrgUUID string `json:"org_uuid,omitempty"`
AccountUUID string `json:"account_uuid,omitempty"`
EmailAddress string `json:"email_address,omitempty"`
}
// ExchangeCode exchanges authorization code for tokens
@@ -176,7 +176,8 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
}
// Determine scope and if this is a setup token
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
// Internal API call uses ScopeAPI (org:create_api_key not supported)
scope := oauth.ScopeAPI
isSetupToken := false
if input.Scope == "inference" {
scope = oauth.ScopeInference
@@ -252,9 +253,15 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
tokenInfo.OrgUUID = tokenResp.Organization.UUID
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
}
if tokenResp.Account != nil && tokenResp.Account.UUID != "" {
tokenInfo.AccountUUID = tokenResp.Account.UUID
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
if tokenResp.Account != nil {
if tokenResp.Account.UUID != "" {
tokenInfo.AccountUUID = tokenResp.Account.UUID
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
}
if tokenResp.Account.EmailAddress != "" {
tokenInfo.EmailAddress = tokenResp.Account.EmailAddress
log.Printf("[OAuth] Got email_address: %s", tokenInfo.EmailAddress)
}
}
return tokenInfo, nil

View File

@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
UpdatedAt string `json:"updated_at,omitempty"`
}
// NormalizedCodexLimits contains normalized 5h/7d rate limit data
type NormalizedCodexLimits struct {
Used5hPercent *float64
Reset5hSeconds *int
Window5hMinutes *int
Used7dPercent *float64
Reset7dSeconds *int
Window7dMinutes *int
}
// Normalize converts primary/secondary fields to canonical 5h/7d fields.
// Strategy: Compare window_minutes to determine which is 5h vs 7d.
// Returns nil if snapshot is nil or has no useful data.
func (s *OpenAICodexUsageSnapshot) Normalize() *NormalizedCodexLimits {
if s == nil {
return nil
}
result := &NormalizedCodexLimits{}
primaryMins := 0
secondaryMins := 0
hasPrimaryWindow := false
hasSecondaryWindow := false
if s.PrimaryWindowMinutes != nil {
primaryMins = *s.PrimaryWindowMinutes
hasPrimaryWindow = true
}
if s.SecondaryWindowMinutes != nil {
secondaryMins = *s.SecondaryWindowMinutes
hasSecondaryWindow = true
}
// Determine mapping based on window_minutes
use5hFromPrimary := false
use7dFromPrimary := false
if hasPrimaryWindow && hasSecondaryWindow {
// Both known: smaller window is 5h, larger is 7d
if primaryMins < secondaryMins {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
}
} else if hasPrimaryWindow {
// Only primary known: classify by threshold (<=360 min = 6h -> 5h window)
if primaryMins <= 360 {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
}
} else if hasSecondaryWindow {
// Only secondary known: classify by threshold
if secondaryMins <= 360 {
// 5h from secondary, so primary (if any data) is 7d
use7dFromPrimary = true
} else {
// 7d from secondary, so primary (if any data) is 5h
use5hFromPrimary = true
}
} else {
// No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h)
use7dFromPrimary = true
}
// Assign values
if use5hFromPrimary {
result.Used5hPercent = s.PrimaryUsedPercent
result.Reset5hSeconds = s.PrimaryResetAfterSeconds
result.Window5hMinutes = s.PrimaryWindowMinutes
result.Used7dPercent = s.SecondaryUsedPercent
result.Reset7dSeconds = s.SecondaryResetAfterSeconds
result.Window7dMinutes = s.SecondaryWindowMinutes
} else if use7dFromPrimary {
result.Used7dPercent = s.PrimaryUsedPercent
result.Reset7dSeconds = s.PrimaryResetAfterSeconds
result.Window7dMinutes = s.PrimaryWindowMinutes
result.Used5hPercent = s.SecondaryUsedPercent
result.Reset5hSeconds = s.SecondaryResetAfterSeconds
result.Window5hMinutes = s.SecondaryWindowMinutes
}
return result
}
// OpenAIUsage represents OpenAI API response usage
type OpenAIUsage struct {
InputTokens int `json:"input_tokens"`
@@ -180,67 +266,26 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. Check sticky session
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// Refresh sticky session TTL
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return account, nil
}
}
}
cacheKey := "openai:" + sessionHash
// 1. 尝试粘性会话命中
// Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil {
return account, nil
}
// 2. Get schedulable OpenAI accounts
// 2. 获取可调度的 OpenAI 账号
// Get schedulable OpenAI accounts
accounts, err := s.listSchedulableAccounts(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 3. Select by priority + LRU
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
// Lower priority value means higher priority
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used)
default:
// Same priority, select least recently used
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
// 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU
selected := s.selectBestAccount(accounts, requestedModel, excludedIDs)
if selected == nil {
if requestedModel != "" {
@@ -249,14 +294,138 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
return nil, errors.New("no available OpenAI accounts")
}
// 4. Set sticky session
// 4. 设置粘性会话绑定
// Set sticky session binding
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL)
}
return selected, nil
}
// tryStickySessionHit 尝试从粘性会话获取账号。
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account {
if sessionHash == "" {
return nil
}
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err != nil || accountID <= 0 {
return nil
}
if _, excluded := excludedIDs[accountID]; excluded {
return nil
}
account, err := s.getSchedulableAccount(ctx, accountID)
if err != nil {
return nil
}
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if shouldClearStickySession(account) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
return nil
}
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if !account.IsSchedulable() || !account.IsOpenAI() {
return nil
}
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil
}
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL)
return account
}
// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU
// 返回 nil 表示无可用账号。
//
// selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no available account.
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
var selected *Account
for i := range accounts {
acc := &accounts[i]
// 跳过被排除的账号
// Skip excluded accounts
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 调度器快照可能暂时过时,这里重新检查可调度性和平台
// Scheduler snapshots can be temporarily stale; re-check schedulability and platform
if !acc.IsSchedulable() || !acc.IsOpenAI() {
continue
}
// 检查模型支持
// Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
if selected == nil {
selected = acc
continue
}
if s.isBetterAccount(acc, selected) {
selected = acc
}
}
return selected
}
// isBetterAccount 判断 candidate 是否比 current 更优。
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
//
// isBetterAccount checks if candidate is better than current.
// Rules: higher priority (lower value) wins; same priority: never used > least recently used.
func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool {
// 优先级更高(数值更小)
// Higher priority (lower value)
if candidate.Priority < current.Priority {
return true
}
if candidate.Priority > current.Priority {
return false
}
// 同优先级,比较最后使用时间
// Same priority, compare last used time
switch {
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
// candidate 从未使用,优先
return true
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
// current 从未使用,保持
return false
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
// 都未使用,保持
return false
default:
// 都使用过,选择最久未使用的
return candidate.LastUsedAt.Before(*current.LastUsedAt)
}
}
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
@@ -325,29 +494,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.getSchedulableAccount(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
}
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
@@ -778,7 +953,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == AccountTypeOAuth {
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
@@ -1576,8 +1751,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
return nil
}
// extractCodexUsageHeaders extracts Codex usage limits from response headers
func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
snapshot := &OpenAICodexUsageSnapshot{}
hasData := false
@@ -1651,6 +1827,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
// Convert snapshot to map for merging into Extra
updates := make(map[string]any)
// Save raw primary/secondary fields for debugging/tracing
if snapshot.PrimaryUsedPercent != nil {
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
}
@@ -1674,109 +1852,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
}
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
// Normalize to canonical 5h/7d fields based on window_minutes
// This fixes the issue where OpenAI's primary/secondary naming is reversed
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d
// IMPORTANT: We can only reliably determine window type from window_minutes field
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison
var primaryWindowMins, secondaryWindowMins int
var hasPrimaryWindow, hasSecondaryWindow bool
// Only use window_minutes for reliable window size comparison
if snapshot.PrimaryWindowMinutes != nil {
primaryWindowMins = *snapshot.PrimaryWindowMinutes
hasPrimaryWindow = true
}
if snapshot.SecondaryWindowMinutes != nil {
secondaryWindowMins = *snapshot.SecondaryWindowMinutes
hasSecondaryWindow = true
}
// Determine which is 5h and which is 7d
var use5hFromPrimary, use7dFromPrimary bool
var use5hFromSecondary, use7dFromSecondary bool
if hasPrimaryWindow && hasSecondaryWindow {
// Both window sizes known: compare and assign smaller to 5h, larger to 7d
if primaryWindowMins < secondaryWindowMins {
use5hFromPrimary = true
use7dFromSecondary = true
} else {
use5hFromSecondary = true
use7dFromPrimary = true
// Normalize to canonical 5h/7d fields
if normalized := snapshot.Normalize(); normalized != nil {
if normalized.Used5hPercent != nil {
updates["codex_5h_used_percent"] = *normalized.Used5hPercent
}
} else if hasPrimaryWindow {
// Only primary window size known: classify by absolute threshold
if primaryWindowMins <= 360 {
use5hFromPrimary = true
} else {
use7dFromPrimary = true
if normalized.Reset5hSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds
}
} else if hasSecondaryWindow {
// Only secondary window size known: classify by absolute threshold
if secondaryWindowMins <= 360 {
use5hFromSecondary = true
} else {
use7dFromSecondary = true
if normalized.Window5hMinutes != nil {
updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes
}
} else {
// No window_minutes available: cannot reliably determine window types
// Fall back to legacy assumption (may be incorrect)
// Assume primary=7d, secondary=5h based on historical observation
if snapshot.SecondaryUsedPercent != nil || snapshot.SecondaryResetAfterSeconds != nil || snapshot.SecondaryWindowMinutes != nil {
use5hFromSecondary = true
if normalized.Used7dPercent != nil {
updates["codex_7d_used_percent"] = *normalized.Used7dPercent
}
if snapshot.PrimaryUsedPercent != nil || snapshot.PrimaryResetAfterSeconds != nil || snapshot.PrimaryWindowMinutes != nil {
use7dFromPrimary = true
if normalized.Reset7dSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds
}
}
// Write canonical 5h fields
if use5hFromPrimary {
if snapshot.PrimaryUsedPercent != nil {
updates["codex_5h_used_percent"] = *snapshot.PrimaryUsedPercent
}
if snapshot.PrimaryResetAfterSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
}
if snapshot.PrimaryWindowMinutes != nil {
updates["codex_5h_window_minutes"] = *snapshot.PrimaryWindowMinutes
}
} else if use5hFromSecondary {
if snapshot.SecondaryUsedPercent != nil {
updates["codex_5h_used_percent"] = *snapshot.SecondaryUsedPercent
}
if snapshot.SecondaryResetAfterSeconds != nil {
updates["codex_5h_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
}
if snapshot.SecondaryWindowMinutes != nil {
updates["codex_5h_window_minutes"] = *snapshot.SecondaryWindowMinutes
}
}
// Write canonical 7d fields
if use7dFromPrimary {
if snapshot.PrimaryUsedPercent != nil {
updates["codex_7d_used_percent"] = *snapshot.PrimaryUsedPercent
}
if snapshot.PrimaryResetAfterSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
}
if snapshot.PrimaryWindowMinutes != nil {
updates["codex_7d_window_minutes"] = *snapshot.PrimaryWindowMinutes
}
} else if use7dFromSecondary {
if snapshot.SecondaryUsedPercent != nil {
updates["codex_7d_used_percent"] = *snapshot.SecondaryUsedPercent
}
if snapshot.SecondaryResetAfterSeconds != nil {
updates["codex_7d_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
}
if snapshot.SecondaryWindowMinutes != nil {
updates["codex_7d_window_minutes"] = *snapshot.SecondaryWindowMinutes
if normalized.Window7dMinutes != nil {
updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes
}
}

View File

@@ -21,19 +21,50 @@ type stubOpenAIAccountRepo struct {
accounts []Account
}
func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
for i := range r.accounts {
if r.accounts[i].ID == id {
return &r.accounts[i], nil
}
}
return nil, errors.New("account not found")
}
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
return append([]Account(nil), r.accounts...), nil
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
result = append(result, acc)
}
}
return result, nil
}
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
return append([]Account(nil), r.accounts...), nil
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
result = append(result, acc)
}
}
return result, nil
}
type stubConcurrencyCache struct {
ConcurrencyCache
loadBatchErr error
loadMap map[int64]*AccountLoadInfo
acquireResults map[int64]bool
waitCounts map[int64]int
skipDefaultLoad bool
}
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if c.acquireResults != nil {
if result, ok := c.acquireResults[accountID]; ok {
return result, nil
}
}
return true, nil
}
@@ -42,8 +73,25 @@ func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID
}
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if c.loadBatchErr != nil {
return nil, c.loadBatchErr
}
out := make(map[int64]*AccountLoadInfo, len(accounts))
if c.skipDefaultLoad && c.loadMap != nil {
for _, acc := range accounts {
if load, ok := c.loadMap[acc.ID]; ok {
out[acc.ID] = load
}
}
return out, nil
}
for _, acc := range accounts {
if c.loadMap != nil {
if load, ok := c.loadMap[acc.ID]; ok {
out[acc.ID] = load
continue
}
}
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
}
return out, nil
@@ -92,6 +140,51 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
}
}
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if c.waitCounts != nil {
if count, ok := c.waitCounts[accountID]; ok {
return count, nil
}
}
return 0, nil
}
type stubGatewayCache struct {
sessionBindings map[string]int64
deletedSessions map[string]int
}
func (c *stubGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
if id, ok := c.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (c *stubGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
if c.sessionBindings == nil {
c.sessionBindings = make(map[string]int64)
}
c.sessionBindings[sessionHash] = accountID
return nil
}
func (c *stubGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
return nil
}
func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
if c.sessionBindings == nil {
return nil
}
if c.deletedSessions == nil {
c.deletedSessions = make(map[string]int)
}
c.deletedSessions[sessionHash]++
delete(c.sessionBindings, sessionHash)
return nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
@@ -182,6 +275,515 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre
}
}
func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) {
sessionHash := "session-1"
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
},
}
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
}
if acc == nil || acc.ID != 2 {
t.Fatalf("expected account 2, got %+v", acc)
}
if cache.deletedSessions["openai:"+sessionHash] != 1 {
t.Fatalf("expected sticky session to be deleted")
}
if cache.sessionBindings["openai:"+sessionHash] != 2 {
t.Fatalf("expected sticky session to bind to account 2")
}
}
func TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession(t *testing.T) {
sessionHash := "session-2"
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
},
}
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
t.Fatalf("expected account 2, got %+v", selection)
}
if cache.deletedSessions["openai:"+sessionHash] != 1 {
t.Fatalf("expected sticky session to be deleted")
}
if cache.sessionBindings["openai:"+sessionHash] != 2 {
t.Fatalf("expected sticky session to bind to account 2")
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
func TestOpenAISelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
repo := stubOpenAIAccountRepo{
accounts: []Account{
{
ID: 1,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"gpt-3.5-turbo": "gpt-3.5-turbo"}},
},
},
}
cache := &stubGatewayCache{}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
if err == nil {
t.Fatalf("expected error for unsupported model")
}
if acc != nil {
t.Fatalf("expected nil account for unsupported model")
}
if !strings.Contains(err.Error(), "supporting model") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
loadBatchErr: errors.New("load batch failed"),
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "fallback", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil {
t.Fatalf("expected selection")
}
if selection.Account.ID != 2 {
t.Fatalf("expected account 2, got %d", selection.Account.ID)
}
if cache.sessionBindings["openai:fallback"] != 2 {
t.Fatalf("expected sticky session updated")
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
func TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
acquireResults: map[int64]bool{1: false},
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 10},
},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.WaitPlan == nil {
t.Fatalf("expected wait plan fallback")
}
if selection.Account == nil || selection.Account.ID != 1 {
t.Fatalf("expected account 1")
}
}
func TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding(t *testing.T) {
sessionHash := "bind"
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
}
if acc == nil || acc.ID != 1 {
t.Fatalf("expected account 1")
}
if cache.sessionBindings["openai:"+sessionHash] != 1 {
t.Fatalf("expected sticky session binding")
}
}
func TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan(t *testing.T) {
sessionHash := "sticky-wait"
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
}
concurrencyCache := stubConcurrencyCache{
acquireResults: map[int64]bool{1: false},
waitCounts: map[int64]int{1: 0},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.WaitPlan == nil {
t.Fatalf("expected sticky wait plan")
}
if selection.Account == nil || selection.Account.ID != 1 {
t.Fatalf("expected account 1")
}
}
func TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 80},
2: {AccountID: 2, LoadRate: 10},
},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "load", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
t.Fatalf("expected account 2")
}
if cache.sessionBindings["openai:load"] != 2 {
t.Fatalf("expected sticky session updated")
}
}
func TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback(t *testing.T) {
sessionHash := "excluded"
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
},
}
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
excluded := map[int64]struct{}{1: {}}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", excluded)
if err != nil {
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
}
if acc == nil || acc.ID != 2 {
t.Fatalf("expected account 2")
}
}
func TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI(t *testing.T) {
sessionHash := "non-openai"
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
},
}
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
}
if acc == nil || acc.ID != 2 {
t.Fatalf("expected account 2")
}
}
func TestOpenAISelectAccountForModelWithExclusions_NoAccounts(t *testing.T) {
repo := stubOpenAIAccountRepo{accounts: []Account{}}
cache := &stubGatewayCache{}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "", nil)
if err == nil {
t.Fatalf("expected error for no accounts")
}
if acc != nil {
t.Fatalf("expected nil account")
}
if !strings.Contains(err.Error(), "no available OpenAI accounts") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) {
groupID := int64(1)
resetAt := time.Now().Add(1 * time.Hour)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &resetAt},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err == nil {
t.Fatalf("expected error for no candidates")
}
if selection != nil {
t.Fatalf("expected nil selection")
}
}
func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 100},
},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.WaitPlan == nil {
t.Fatalf("expected wait plan")
}
}
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
loadBatchErr: errors.New("load batch failed"),
acquireResults: map[int64]bool{1: false},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.WaitPlan == nil {
t.Fatalf("expected wait plan")
}
}
func TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 50},
},
skipDefaultLoad: true,
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
t.Fatalf("expected account 2")
}
}
func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.T) {
oldTime := time.Now().Add(-2 * time.Hour)
newTime := time.Now().Add(-1 * time.Hour)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &newTime},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &oldTime},
},
}
cache := &stubGatewayCache{}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
}
if acc == nil || acc.ID != 2 {
t.Fatalf("expected account 2")
}
}
func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) {
groupID := int64(1)
lastUsed := time.Now().Add(-1 * time.Hour)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, LastUsedAt: &lastUsed},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 10},
2: {AccountID: 2, LoadRate: 10},
},
}
svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
t.Fatalf("expected account 2")
}
}
func TestOpenAIStreamingTimeout(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{

View File

@@ -2,9 +2,10 @@ package service
import (
"context"
"fmt"
"net/http"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
@@ -35,12 +36,12 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// Generate PKCE values
state, err := openai.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_STATE_FAILED", "failed to generate state: %v", err)
}
codeVerifier, err := openai.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_VERIFIER_FAILED", "failed to generate code verifier: %v", err)
}
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
@@ -48,14 +49,17 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// Generate session ID
sessionID, err := openai.GenerateSessionID()
if err != nil {
return nil, fmt.Errorf("failed to generate session ID: %w", err)
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_SESSION_FAILED", "failed to generate session ID: %v", err)
}
// Get proxy URL if specified
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil {
if err != nil {
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy != nil {
proxyURL = proxy.URL()
}
}
@@ -110,14 +114,17 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Get session
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
return nil, fmt.Errorf("session not found or expired")
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
}
// Get proxy URL
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
proxyURL := session.ProxyURL
if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err == nil && proxy != nil {
if err != nil {
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy != nil {
proxyURL = proxy.URL()
}
}
@@ -131,7 +138,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Exchange code for token
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err)
return nil, err
}
// Parse ID token to get user info
@@ -201,12 +208,12 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
// RefreshAccountToken refreshes token for an OpenAI account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if !account.IsOpenAI() {
return nil, fmt.Errorf("account is not an OpenAI account")
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
}
refreshToken := account.GetOpenAIRefreshToken()
if refreshToken == "" {
return nil, fmt.Errorf("no refresh token available")
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
}
var proxyURL string

View File

@@ -162,26 +162,37 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > openAITokenCacheSkew:
ttl = until - openAITokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("openai_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetOpenAIAccessToken()
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > openAITokenCacheSkew:
ttl = until - openAITokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}

View File

@@ -343,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
// 解析重置时间戳
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded
if account.Platform == PlatformOpenAI {
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("openai_account_rate_limited", "account_id", account.ID, "reset_at", *resetAt)
return
}
}
// 2. 尝试从响应头解析重置时间Anthropic
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
// 3. 如果响应头没有尝试从响应体解析OpenAI usage_limit_reached, Gemini
if resetTimestamp == "" {
switch account.Platform {
case PlatformOpenAI:
// 尝试解析 OpenAI 的 usage_limit_reached 错误
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
resetTime := time.Unix(*resetAt, 0)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
return
}
case PlatformGemini, PlatformAntigravity:
// 尝试解析 Gemini 格式(用于其他平台)
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
resetTime := time.Unix(*resetAt, 0)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
return
}
}
// 没有重置时间使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
@@ -356,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
}
return
}
slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
@@ -419,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re
return strings.Contains(msg, "sonnet")
}
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
snapshot := ParseCodexRateLimitHeaders(headers)
if snapshot == nil {
return nil
}
normalized := snapshot.Normalize()
if normalized == nil {
return nil
}
now := time.Now()
// 判断哪个限制被触发used_percent >= 100
is7dExhausted := normalized.Used7dPercent != nil && *normalized.Used7dPercent >= 100
is5hExhausted := normalized.Used5hPercent != nil && *normalized.Used5hPercent >= 100
// 优先使用被触发限制的重置时间
if is7dExhausted && normalized.Reset7dSeconds != nil {
resetAt := now.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
slog.Info("openai_429_7d_limit_exhausted", "reset_after_seconds", *normalized.Reset7dSeconds, "reset_at", resetAt)
return &resetAt
}
if is5hExhausted && normalized.Reset5hSeconds != nil {
resetAt := now.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
slog.Info("openai_429_5h_limit_exhausted", "reset_after_seconds", *normalized.Reset5hSeconds, "reset_at", resetAt)
return &resetAt
}
// 都未达到100%但收到429使用较长的重置时间
var maxResetSecs int
if normalized.Reset7dSeconds != nil && *normalized.Reset7dSeconds > maxResetSecs {
maxResetSecs = *normalized.Reset7dSeconds
}
if normalized.Reset5hSeconds != nil && *normalized.Reset5hSeconds > maxResetSecs {
maxResetSecs = *normalized.Reset5hSeconds
}
if maxResetSecs > 0 {
resetAt := now.Add(time.Duration(maxResetSecs) * time.Second)
slog.Info("openai_429_using_max_reset", "max_reset_seconds", maxResetSecs, "reset_at", resetAt)
return &resetAt
}
return nil
}
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
// OpenAI 的 usage_limit_reached 错误格式:
//
// {
// "error": {
// "message": "The usage limit has been reached",
// "type": "usage_limit_reached",
// "resets_at": 1769404154,
// "resets_in_seconds": 133107
// }
// }
func parseOpenAIRateLimitResetTime(body []byte) *int64 {
var parsed map[string]any
if err := json.Unmarshal(body, &parsed); err != nil {
return nil
}
errObj, ok := parsed["error"].(map[string]any)
if !ok {
return nil
}
// 检查是否为 usage_limit_reached 或 rate_limit_exceeded 类型
errType, _ := errObj["type"].(string)
if errType != "usage_limit_reached" && errType != "rate_limit_exceeded" {
return nil
}
// 优先使用 resets_atUnix 时间戳)
if resetsAt, ok := errObj["resets_at"].(float64); ok {
ts := int64(resetsAt)
return &ts
}
if resetsAt, ok := errObj["resets_at"].(string); ok {
if ts, err := strconv.ParseInt(resetsAt, 10, 64); err == nil {
return &ts
}
}
// 如果没有 resets_at尝试使用 resets_in_seconds
if resetsInSeconds, ok := errObj["resets_in_seconds"].(float64); ok {
ts := time.Now().Unix() + int64(resetsInSeconds)
return &ts
}
if resetsInSeconds, ok := errObj["resets_in_seconds"].(string); ok {
if sec, err := strconv.ParseInt(resetsInSeconds, 10, 64); err == nil {
ts := time.Now().Unix() + sec
return &ts
}
}
return nil
}
// handle529 处理529过载错误
// 根据配置设置过载冷却时间
func (s *RateLimitService) handle529(ctx context.Context, account *Account) {

View File

@@ -0,0 +1,364 @@
package service
import (
"net/http"
"testing"
"time"
)
func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) {
svc := &RateLimitService{}
// Simulate headers when 7d limit is exhausted (100% used)
// Primary = 7d (10080 minutes), Secondary = 5h (300 minutes)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "384607") // ~4.5 days
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
headers.Set("x-codex-secondary-used-percent", "3")
headers.Set("x-codex-secondary-reset-after-seconds", "17369") // ~4.8 hours
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should be approximately 384607 seconds from now
expectedDuration := 384607 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestCalculateOpenAI429ResetTime_5hExhausted(t *testing.T) {
svc := &RateLimitService{}
// Simulate headers when 5h limit is exhausted (100% used)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "50")
headers.Set("x-codex-primary-reset-after-seconds", "500000")
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
headers.Set("x-codex-secondary-used-percent", "100")
headers.Set("x-codex-secondary-reset-after-seconds", "3600") // 1 hour
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should be approximately 3600 seconds from now
expectedDuration := 3600 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestCalculateOpenAI429ResetTime_NeitherExhausted_UsesMax(t *testing.T) {
svc := &RateLimitService{}
// Neither limit at 100%, should use the longer reset time
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "80")
headers.Set("x-codex-primary-reset-after-seconds", "100000")
headers.Set("x-codex-primary-window-minutes", "10080")
headers.Set("x-codex-secondary-used-percent", "90")
headers.Set("x-codex-secondary-reset-after-seconds", "5000")
headers.Set("x-codex-secondary-window-minutes", "300")
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should use the max (100000 seconds from 7d window)
expectedDuration := 100000 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestCalculateOpenAI429ResetTime_NoCodexHeaders(t *testing.T) {
svc := &RateLimitService{}
// No codex headers at all
headers := http.Header{}
headers.Set("content-type", "application/json")
resetAt := svc.calculateOpenAI429ResetTime(headers)
if resetAt != nil {
t.Errorf("expected nil resetAt when no codex headers, got %v", resetAt)
}
}
func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) {
svc := &RateLimitService{}
// Test when OpenAI sends primary as 5h and secondary as 7d (reversed)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100") // This is 5h
headers.Set("x-codex-primary-reset-after-seconds", "3600") // 1 hour
headers.Set("x-codex-primary-window-minutes", "300") // 5 hours - smaller!
headers.Set("x-codex-secondary-used-percent", "50")
headers.Set("x-codex-secondary-reset-after-seconds", "500000")
headers.Set("x-codex-secondary-window-minutes", "10080") // 7 days - larger!
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt")
}
// Should correctly identify that primary is 5h (smaller window) and use its reset time
expectedDuration := 3600 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
}
func TestNormalizedCodexLimits(t *testing.T) {
// Test the Normalize() method directly
pUsed := 100.0
pReset := 384607
pWindow := 10080
sUsed := 3.0
sReset := 17369
sWindow := 300
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: &pUsed,
PrimaryResetAfterSeconds: &pReset,
PrimaryWindowMinutes: &pWindow,
SecondaryUsedPercent: &sUsed,
SecondaryResetAfterSeconds: &sReset,
SecondaryWindowMinutes: &sWindow,
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Primary has larger window (10080 > 300), so primary should be 7d
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
}
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 384607 {
t.Errorf("expected Reset7dSeconds=384607, got %v", normalized.Reset7dSeconds)
}
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 3.0 {
t.Errorf("expected Used5hPercent=3, got %v", normalized.Used5hPercent)
}
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 17369 {
t.Errorf("expected Reset5hSeconds=17369, got %v", normalized.Reset5hSeconds)
}
}
func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) {
// Test when only primary has data, no window_minutes
pUsed := 80.0
pReset := 50000
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: &pUsed,
PrimaryResetAfterSeconds: &pReset,
// No window_minutes, no secondary data
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Legacy assumption: primary=7d, secondary=5h
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 80.0 {
t.Errorf("expected Used7dPercent=80, got %v", normalized.Used7dPercent)
}
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 50000 {
t.Errorf("expected Reset7dSeconds=50000, got %v", normalized.Reset7dSeconds)
}
// Secondary (5h) should be nil
if normalized.Used5hPercent != nil {
t.Errorf("expected Used5hPercent=nil, got %v", *normalized.Used5hPercent)
}
if normalized.Reset5hSeconds != nil {
t.Errorf("expected Reset5hSeconds=nil, got %v", *normalized.Reset5hSeconds)
}
}
func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) {
// Test when only secondary has data, no window_minutes
sUsed := 60.0
sReset := 3000
snapshot := &OpenAICodexUsageSnapshot{
SecondaryUsedPercent: &sUsed,
SecondaryResetAfterSeconds: &sReset,
// No window_minutes, no primary data
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Legacy assumption: primary=7d, secondary=5h
// So secondary goes to 5h
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 60.0 {
t.Errorf("expected Used5hPercent=60, got %v", normalized.Used5hPercent)
}
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 3000 {
t.Errorf("expected Reset5hSeconds=3000, got %v", normalized.Reset5hSeconds)
}
// Primary (7d) should be nil
if normalized.Used7dPercent != nil {
t.Errorf("expected Used7dPercent=nil, got %v", *normalized.Used7dPercent)
}
}
func TestNormalizedCodexLimits_BothDataNoWindowMinutes(t *testing.T) {
// Test when both have data but no window_minutes
pUsed := 100.0
pReset := 400000
sUsed := 50.0
sReset := 10000
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: &pUsed,
PrimaryResetAfterSeconds: &pReset,
SecondaryUsedPercent: &sUsed,
SecondaryResetAfterSeconds: &sReset,
// No window_minutes
}
normalized := snapshot.Normalize()
if normalized == nil {
t.Fatal("expected non-nil normalized")
}
// Legacy assumption: primary=7d, secondary=5h
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
}
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 400000 {
t.Errorf("expected Reset7dSeconds=400000, got %v", normalized.Reset7dSeconds)
}
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 50.0 {
t.Errorf("expected Used5hPercent=50, got %v", normalized.Used5hPercent)
}
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 10000 {
t.Errorf("expected Reset5hSeconds=10000, got %v", normalized.Reset5hSeconds)
}
}
func TestHandle429_AnthropicPlatformUnaffected(t *testing.T) {
// Verify that Anthropic platform accounts still use the original logic
// This test ensures we don't break existing Claude account rate limiting
svc := &RateLimitService{}
// Simulate Anthropic 429 headers
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-reset", "1737820800") // A future Unix timestamp
// For Anthropic platform, calculateOpenAI429ResetTime should return nil
// because it only handles OpenAI platform
resetAt := svc.calculateOpenAI429ResetTime(headers)
// Should return nil since there are no x-codex-* headers
if resetAt != nil {
t.Errorf("expected nil for Anthropic headers, got %v", resetAt)
}
}
func TestCalculateOpenAI429ResetTime_UserProvidedScenario(t *testing.T) {
// This is the exact scenario from the user:
// codex_7d_used_percent: 100
// codex_7d_reset_after_seconds: 384607 (约4.5天后重置)
// codex_5h_used_percent: 3
// codex_5h_reset_after_seconds: 17369 (约4.8小时后重置)
svc := &RateLimitService{}
// Simulate headers matching user's data
// Note: We need to map the canonical 5h/7d back to primary/secondary
// Based on typical OpenAI behavior: primary=7d (larger window), secondary=5h (smaller window)
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "384607")
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days = 10080 minutes
headers.Set("x-codex-secondary-used-percent", "3")
headers.Set("x-codex-secondary-reset-after-seconds", "17369")
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours = 300 minutes
before := time.Now()
resetAt := svc.calculateOpenAI429ResetTime(headers)
after := time.Now()
if resetAt == nil {
t.Fatal("expected non-nil resetAt for user scenario")
}
// Should use the 7d reset time (384607 seconds) since 7d limit is exhausted (100%)
expectedDuration := 384607 * time.Second
minExpected := before.Add(expectedDuration)
maxExpected := after.Add(expectedDuration)
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
}
// Verify it's approximately 4.45 days (384607 seconds)
duration := resetAt.Sub(before)
actualDays := duration.Hours() / 24.0
// 384607 / 86400 = ~4.45 days
if actualDays < 4.4 || actualDays > 4.5 {
t.Errorf("expected ~4.45 days, got %.2f days", actualDays)
}
t.Logf("User scenario: reset_at=%v, duration=%.2f days", resetAt, actualDays)
}
func TestCalculateOpenAI429ResetTime_5MinFallbackWhenNoReset(t *testing.T) {
// Test that we return nil when there's used_percent but no reset_after_seconds
// This should cause the caller to use the default 5-minute fallback
svc := &RateLimitService{}
headers := http.Header{}
headers.Set("x-codex-primary-used-percent", "100")
// No reset_after_seconds!
resetAt := svc.calculateOpenAI429ResetTime(headers)
// Should return nil since there's no reset time available
if resetAt != nil {
t.Errorf("expected nil when no reset_after_seconds, got %v", resetAt)
}
}

View File

@@ -38,8 +38,9 @@ type SessionLimitCache interface {
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
// idleTimeouts: 每个账号的空闲超时时间配置key 为 accountID若为 nil 或某账号不在其中,则使用默认超时
// 返回 map[accountID]count查询失败的账号不在 map 中
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error)
// IsSessionActive 检查特定会话是否活跃(未过期)
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)

Some files were not shown because too many files have changed in this diff Show More