Compare commits

..

100 Commits

Author SHA1 Message Date
shaw
80cce858cb fix(frontend): 修复添加账号弹窗宽度和tooltip被截断问题
- 将弹窗尺寸从 lg 改为 xl,增加内容显示空间
- 修复 AI Studio tooltip 被弹窗边界截断的问题
  - 调整定位从 left-0 改为 right-0
  - 减小宽度从 w-[28rem] 改为 w-80
  - 提高 z-index 确保正确显示
2025-12-27 17:04:38 +08:00
shaw
0743652d92 Merge branch 'feature/ui-and-backend-improvements' 2025-12-27 16:33:57 +08:00
IanShaw027
96bec5c9b1 fix(test): 实现GetUserStatsAggregated方法以支持新的统计查询
- 在stubUsageLogRepo中实现GetUserStatsAggregated方法
- 根据userLogs计算统计数据而不是返回错误
- 修复类型转换问题(int转int64)
2025-12-27 16:20:59 +08:00
IanShaw027
cfeb6b8b14 fix(test): 添加缺失的GetUserStatsAggregated方法 2025-12-27 16:15:32 +08:00
IanShaw027
481310dea0 fix(test): 修复CI测试失败
- 修复gofmt格式问题
- 为stubUsageLogRepo添加缺失的GetApiKeyStatsAggregated方法
2025-12-27 16:12:06 +08:00
IanShaw027
ea2821d11d refactor(frontend): 优化用户视图和设置向导
- 改进API密钥管理界面
- 优化用户使用统计视图
- 完善初始化设置向导
2025-12-27 16:05:36 +08:00
IanShaw027
7a0de1765f refactor(frontend): 优化管理后台视图
- 改进账户管理视图
- 优化分组管理界面
- 完善代理管理功能
- 增强兑换码管理
- 改进订阅管理视图
- 优化使用统计展示
- 完善用户管理界面
2025-12-27 16:05:16 +08:00
IanShaw027
35b1bc3753 refactor(frontend): 优化格式化工具函数
- 改进数据格式化逻辑
- 增强工具函数可读性
2025-12-27 16:04:56 +08:00
IanShaw027
8d38788672 feat(frontend): 更新国际化翻译
- 新增英文翻译条目
- 完善中文翻译内容
- 改进多语言支持
2025-12-27 16:04:35 +08:00
IanShaw027
c615a4264d refactor(frontend): 优化通用组件
- 改进ConfirmDialog对话框组件
- 增强DataTable表格组件功能和响应式布局
- 优化EmptyState空状态组件
- 完善SubscriptionProgressMini订阅进度组件
2025-12-27 16:04:16 +08:00
IanShaw027
227d506c53 feat(backend): 增强使用统计和API密钥功能
- 优化使用统计处理逻辑
- 增强API密钥仓储层功能
- 改进账户使用服务
- 完善API契约测试覆盖
2025-12-27 16:03:57 +08:00
IanShaw027
36a86e9ab4 perf(backend): 优化数据库查询性能
- 合并多个独立查询为单个SQL查询
- 减少数据库往返次数
- 提升仪表板统计数据获取效率
2025-12-27 16:03:37 +08:00
shaw
f133b051dc fix: 修复TG通知workflow语法错误
- 移除if条件中对secrets的直接引用(GitHub Actions不支持)
- 改用shell脚本内部检查环境变量是否存在
2025-12-27 16:03:13 +08:00
shaw
7af1bdbf4c chore: workflow增加TG频道更新通知 2025-12-27 15:55:09 +08:00
shaw
016d7ef645 feat: 增强前端clipboard功能 2025-12-27 15:16:52 +08:00
shaw
f1e47291cd fix: 修复账号更新时分组绑定操作顺序导致的数据不一致问题
原逻辑先执行 Update 再验证 GroupIDs,如果验证失败会导致账号已更新但返回错误。
现改为先验证分组是否存在,再执行 Update 和 BindGroups。
2025-12-27 14:57:43 +08:00
shaw
d7e9ae38e4 Merge PR #49: feat: cc/codex/gemini 增加账号重试功能 2025-12-27 13:59:00 +08:00
程序猿MT
88be981afc feat: (#47)
golang 1.24-> 1.25
node 20 -> node 24
具体提升请查看官方文档

Co-authored-by: yangjianbo <yangjianbo@leagsoft.com>
2025-12-27 13:56:14 +08:00
IanShaw
3f92a43170 test: 完善 UsageLogRepo 测试 stub 的过滤逻辑 (#50) 2025-12-27 13:53:47 +08:00
shaw
2101f1d1c8 fix: 修复claude OAuth账户刷新token失败的bug 2025-12-27 13:50:35 +08:00
daodao97
f0f920e49f feat: cc/codex/gemini 增加账号重试 2025-12-27 12:27:47 +08:00
daodao97
95583fce83 feat: cc/codex support account retry 2025-12-27 12:05:38 +08:00
IanShaw
254f12543c feat(frontend): 前端界面优化与使用统计功能增强 (#46)
* feat(frontend): 前端界面优化与使用统计功能增强

主要改动:

1. 表格布局统一优化
   - 新增 TablePageLayout 通用布局组件
   - 统一所有管理页面的表格样式和交互
   - 优化 DataTable、Pagination、Select 等通用组件

2. 使用统计功能增强
   - 管理端: 添加完整的筛选和显示功能
   - 用户端: 完善 API Key 列显示
   - 后端: 优化使用统计数据结构和查询

3. 账户组件优化
   - 优化 AccountStatsModal、AccountUsageCell 等组件
   - 统一进度条和统计显示样式

4. 其他改进
   - 完善中英文国际化
   - 统一页面样式和交互体验
   - 优化各视图页面的响应式布局

* fix(test): 修复 stubUsageLogRepo.ListWithFilters 测试 stub

测试用例 GET /api/v1/usage 返回 500 是因为 stub 方法未实现,
现在正确返回基于 UserID 过滤的日志数据。

* feat(frontend): 统一日期时间显示格式

**主要改动**:
1. 增强 utils/format.ts:
   - 新增 formatDateOnly() - 格式: YYYY-MM-DD
   - 新增 formatDateTime() - 格式: YYYY-MM-DD HH:mm:ss

2. 全局替换视图中的格式化函数:
   - 移除各视图中的自定义 formatDate 函数
   - 统一导入使用 @/utils/format 中的函数
   - created_at/updated_at 使用 formatDateTime
   - expires_at 使用 formatDateOnly

3. 受影响的视图 (8个):
   - frontend/src/views/user/KeysView.vue
   - frontend/src/views/user/DashboardView.vue
   - frontend/src/views/user/UsageView.vue
   - frontend/src/views/user/RedeemView.vue
   - frontend/src/views/admin/UsersView.vue
   - frontend/src/views/admin/UsageView.vue
   - frontend/src/views/admin/RedeemView.vue
   - frontend/src/views/admin/SubscriptionsView.vue

**效果**:
- 日期统一显示为 YYYY-MM-DD
- 时间统一显示为 YYYY-MM-DD HH:mm:ss
- 提升可维护性,避免格式不一致

* fix(frontend): 补充遗漏的时间格式化统一

**补充修复**(基于 code review 发现的遗漏):

1. 增强 utils/format.ts:
   - 新增 formatTime() - 格式: HH:mm

2. 修复 4 个遗漏的文件:
   - src/views/admin/UsersView.vue
     * 删除 formatExpiresAt(),改用 formatDateTime()
     * 修复订阅过期时间 tooltip 显示格式不一致问题

   - src/views/user/ProfileView.vue
     * 删除 formatMemberSince(),改用 formatDate(date, 'YYYY-MM')
     * 统一会员起始时间显示格式

   - src/views/user/SubscriptionsView.vue
     * 修改 formatExpirationDate() 使用 formatDateOnly()
     * 保留天数计算逻辑

   - src/components/account/AccountStatusIndicator.vue
     * 删除本地 formatTime(),改用 utils/format 中的统一函数
     * 修复 rate limit 和 overload 重置时间显示

**验证**:
- TypeScript 类型检查通过 ✓
- 前端构建成功 ✓
- 所有剩余的 toLocaleString() 都是数字格式化,属于正确用法 ✓

**效果**:
- 订阅过期时间统一为 YYYY-MM-DD HH:mm:ss
- 会员起始时间统一为 YYYY-MM
- 重置时间统一为 HH:mm
- 消除所有不规范的原生 locale 方法调用
2025-12-27 10:50:25 +08:00
IanShaw
cf8a64528c fix: 修复 Gemini API 认证和 /responses 端点路由问题 (#45)
* fix(middleware): 修复 Gemini API Key 认证中间件用户上下文类型错误

修复了 ApiKeyAuthWithSubscriptionGoogle 中间件中设置用户上下文时的类型错误。

**问题:**
- 中间件直接设置 `apiKey.User` 对象到上下文
- 导致 handler 中获取 `AuthSubject` 时类型断言失败
- 所有 Gemini v1beta 端点返回 500 "User context not found"

**修复:**
- 改为设置 `AuthSubject` 结构体,与 `api_key_auth.go` 保持一致
- 添加 `ContextKeyUserRole` 设置以完整支持角色检查

**影响范围:**
- Gemini v1beta API 端点 (generateContent, streamGenerateContent)
- 使用 Google API Key 认证的所有请求

**测试:**
- 验证 Gemini CLI 调用成功返回 200
- 确认用户上下文正确传递到 handler

* fix(web): 修复 /responses 端点被前端中间件拦截的问题

- 将 /responses 路径添加到 API 白名单,防止其被当作前端路由处理
- 修复 /responses 端点返回 HTML 而非 API 响应的 BUG
- 解决 codex CLI stream 在远程服务器上断开连接的问题

根本原因:
在 6c469b4 提交中添加了 /responses 路由,但未同步更新前端嵌入中间件
的 API 白名单,导致该路由被拦截并返回 index.html 而非 API 响应。
2025-12-27 10:50:15 +08:00
shaw
2b79c4e8b7 chore: home页面更新gemini为已支持 2025-12-26 23:13:00 +08:00
shaw
429f38d0c9 Merge PR #37: Add Gemini OAuth and Messages Compat Support 2025-12-26 22:42:34 +08:00
IanShaw027
2714be99a9 feat(test): 添加 Gemini 双响应格式支持
添加对两种 Gemini 响应格式的支持:
- AI Studio: `{"candidates": [...]}`
- Gemini CLI: `{"response": {"candidates": [...]}}`

通过 unwrap 逻辑自动检测并适配两种格式,确保账号测试功能
对所有 Gemini 账号类型都能正常工作。

合并 PR #43 的剩余功能到 PR #37
2025-12-26 22:31:12 +08:00
IanShaw027
d851818035 fix(lint): 修复 gofmt 格式问题
修复 golangci-lint 检查失败的问题:
- gemini_token_provider.go: 删除 import 后多余空行
- gemini_token_refresher.go: 删除 import 后多余空行

Fixes CI golangci-lint check for PR #37
2025-12-26 22:24:22 +08:00
IanShaw027
576bf4639c refactor: 统一使用 mergeMap 函数提升代码一致性
根据 Gemini CLI 代码审查建议:

## 修改内容
- 将 Gemini OAuth 同步中的 `mergeJSONB` 调用替换为 `mergeMap`
- 删除不再使用的 `mergeJSONB` 函数定义

## 原因
- 其他平台(OpenAI、Anthropic)的账户同步都使用 `mergeMap`
- `mergeJSONB` 是为旧的 `model.JSONB` 类型设计,与重构后的架构不一致
- 统一函数命名提高代码可读性和可维护性

## 影响范围
- backend/internal/service/crs_sync_service.go (4处替换)
- backend/internal/service/account.go (删除 mergeJSONB 函数)

## 验证
✓ 编译通过
✓ 功能逻辑无变化(mergeMap 和 mergeJSONB 实现相同)
2025-12-26 22:15:15 +08:00
IanShaw027
9db52838b5 fix(backend): 适配重构后的架构修复 Gemini OAuth 集成
## 主要修改

1. **移除 model 包引用**
   - 删除所有 `internal/model` 包的 import
   - 使用 service 包中的类型定义(Account, Platform常量等)

2. **修复类型转换**
   - JSONB → map[string]any
   - 添加 mergeJSONB 辅助函数
   - 添加 Account.IsGemini() 方法

3. **更新中间件调用**
   - GetUserFromContext → GetAuthSubjectFromContext
   - 适配新的并发控制签名(传递 ID 和 Concurrency 而不是完整对象)

4. **修复 handler 层**
   - 更新 gemini_v1beta_handler.go
   - 修正 billing 检查和 usage 记录

## 影响范围
- backend/internal/service/gemini_*.go
- backend/internal/service/account_test_service.go
- backend/internal/service/crs_sync_service.go
- backend/internal/handler/gemini_v1beta_handler.go
- backend/internal/handler/gateway_handler.go
- backend/internal/handler/admin/account_handler.go
2025-12-26 22:07:55 +08:00
IanShaw027
bfcd9501c2 merge: 合并 upstream/main 解决 PR #37 冲突
- 删除 backend/internal/model/account.go 符合重构方向
- 合并最新的项目结构重构
- 包含 SSE 格式解析修复
- 更新依赖和配置文件
2025-12-26 21:56:08 +08:00
September999999999
12252c6005 fix: 卸载时删除安装锁文件以支持重新安装 (#39)
- 卸载时自动删除 .installed 安装锁文件
- 新增 --purge 参数支持完全清理(包括配置目录)
- 交互模式下增加是否删除配置目录的确认提示
- 支持中英文消息
2025-12-26 21:32:22 +08:00
shaw
2d89f36687 Merge PR #42: fix(sse): 修复非标准 SSE 格式解析问题 2025-12-26 21:31:34 +08:00
shaw
3d608c2625 Merge branch 'refactor/redis-key-helpers' 2025-12-26 21:26:18 +08:00
shaw
739d0ee61e fix: admin handlers 添加 DTO 转换修复 JSON 序列化
修复 PR #36 合并后部分 admin handler 直接返回 service 层对象导致
JSON 字段名为 PascalCase 而非期望的 snake_case 问题。

修复内容:
- account_handler: Refresh 接口添加 dto.AccountFromService
- openai_oauth_handler: RefreshAccountToken/CreateAccountFromOAuth 添加 dto 转换
- subscription_handler: BulkAssign 添加 dto.BulkAssignResultFromService
- usage_handler: List 接口添加 dto.UsageLogFromService 转换
- 新增 dto.BulkAssignResult 类型和对应的 mapper 函数
2025-12-26 21:22:48 +08:00
shaw
22f07a7bb6 Merge PR #36: refactor: 调整项目结构为单向依赖 2025-12-26 20:08:26 +08:00
ianshaw
16eec4eb41 fix(sse): 修复非标准 SSE 格式解析问题
部分上游 API 返回的 SSE 格式不符合标准规范:
- 标准格式: `data: {...}`(冒号后有空格)
- 非标准格式: `data:{...}`(冒号后无空格)

使用预编译正则 `^data:\s*` 统一处理两种格式。
2025-12-26 03:49:55 -08:00
shaw
ecb2c5353c fix: 修复docker-compose.yml redis密码传递问题 2025-12-26 17:25:12 +08:00
Forest
06d5876b02 refactor: 封装 Redis key 生成函数 2025-12-26 16:47:44 +08:00
Forest
e5a77853b0 refactor: 调整项目结构为单向依赖 2025-12-26 16:45:40 +08:00
ianshaw
9780f0fd9d fix(backend): 修复 rebase 后的代码集成问题
- 更新 middleware import 路径到 internal/server/middleware
- 修复 api_key_auth_google.go 使用正确的 service 类型
- 更新 router.go 和 http.go 支持 Gemini v1beta 路由
- 在 routes/gateway.go 中添加 Gemini v1beta API 端点
- 在 routes/admin.go 中添加 Gemini OAuth 路由
- 更新 wire.go 添加 GeminiOAuthService cleanup
- 重新生成 wire_gen.go
2025-12-26 00:17:55 -08:00
ianshaw
3559830882 fix(service): 应用德摩根定律修复 staticcheck QF1001 警告 2025-12-26 00:11:04 -08:00
ianshaw
5594680130 docs(deploy): 说明 AI Studio OAuth Client 需发布为正式版本
README.md:
- 添加第 7 步:发布 OAuth 应用到正式版本
- 说明 Testing 模式限制(100 用户、7 天 token 过期)
- 说明 sensitive scope 可能需要 Google 审核

.env.example:
- 添加 OAuth Client 需发布为正式版本的说明
2025-12-26 00:11:04 -08:00
ianshaw
50855ec15f feat(i18n): 添加 AI Studio OAuth 配置状态相关文案
- 添加 aiStudioNotConfiguredShort/Tip/aiStudioNotConfigured 翻译
- 更新 needsProjectIdDesc/noProjectIdNeededDesc 描述更准确
2025-12-26 00:11:04 -08:00
ianshaw
f9f33e7b5c feat(frontend): UI 显示 AI Studio OAuth 配置状态
CreateAccountModal:
- 检测 AI Studio OAuth 是否可用(服务器配置了自定义 OAuth 客户端)
- 未配置时禁用 AI Studio 选项并显示提示
- 添加悬停提示说明配置要求

ReAuthAccountModal:
- 同步 AI Studio OAuth 可用性检测逻辑
- 未配置时自动回退到 Code Assist
2025-12-26 00:11:03 -08:00
ianshaw
1bec35999b feat(frontend): 添加 Gemini OAuth 能力查询 API
- 添加 GeminiOAuthCapabilities 类型定义
- 添加 getCapabilities API 函数
- useGeminiOAuth composable 导出 getCapabilities 方法
2025-12-26 00:11:03 -08:00
ianshaw
632318ad33 feat(backend): 添加 OAuth 能力查询接口,改进 OAuth 客户端选择逻辑
Handler 改进:
- 添加 GET /api/v1/admin/gemini/oauth/capabilities 接口
- 简化 GenerateAuthURL,redirect_uri 由服务层决定

Repository 改进:
- ExchangeCode/RefreshToken 根据 oauthType 选择正确的 OAuth 客户端
- Code Assist 始终使用内置客户端,AI Studio 使用用户配置的客户端
2025-12-26 00:11:03 -08:00
ianshaw
456e8984b0 feat(service): 改进 Gemini OAuth 服务层,区分 Code Assist 和 AI Studio 客户端
OAuth 服务改进:
- 添加 GetOAuthConfig 返回 AI Studio OAuth 可用性
- Code Assist 强制使用内置 Gemini CLI 客户端
- AI Studio OAuth 要求用户配置自定义 OAuth 客户端
- ExchangeCode/RefreshToken 接口添加 oauthType 参数
- 添加 unauthorized_client 错误的向后兼容重试逻辑

兼容层改进:
- 403 重试逻辑仅对 Code Assist OAuth 生效
- 添加 insufficient-scope 错误检测,避免无效重试
- 上游错误消息脱敏处理(隐藏 API key 等敏感信息)
- 改进错误提示,显示更多上游错误详情
2025-12-26 00:11:03 -08:00
ianshaw
eea949853a feat(geminicli): 添加内置 Gemini CLI OAuth 客户端常量和改进配置逻辑
- 添加 GeminiCLIOAuthClientID/Secret 常量(Gemini CLI 公开 OAuth 客户端)
- 更新 DefaultAIStudioScopes 使用 generative-language.retriever(符合 Google 文档)
- EffectiveOAuthConfig 支持自动回退到内置客户端
- 内置客户端自动过滤受限 scope(如 generative-language)
- 添加 scope 向后兼容性处理
2025-12-26 00:11:03 -08:00
ianshaw
85fd1e4a2c fix(backend): 移除对已删除 ports 包的依赖
适配 main 分支的 ports 目录删除重构:
- 将 ports 包中的接口移至 service 包
- 更新 repository 层的导入路径
2025-12-26 00:11:03 -08:00
ianshaw
6682d06c99 fix(backend): 修复 golangci-lint 报告的格式和代码规范问题
- gofmt: 修复 account_handler.go, models.go, gemini_messages_compat_service.go 的格式
- staticcheck ST1005: 将 error strings 改为小写开头
2025-12-26 00:11:03 -08:00
ianshaw
efa470efc7 fix(backend): 修复 golangci-lint 报告的问题
- gofmt: 修复代码格式问题
- errcheck: 处理 WriteString 和 Close 返回值
- staticcheck: 错误信息改为小写开头
- staticcheck: 移除无效的 nil 检查
- staticcheck: 使用 append 替换循环
- staticcheck: 使用无条件的 TrimPrefix
- ineffassign: 移除无效赋值
- unused: 移除未使用的 geminiOAuthService 字段
- 重新生成 wire_gen.go
2025-12-26 00:11:03 -08:00
ianshaw
79d1585250 docs(deploy): 更新部署配置和文档
- .env.example: 新增 Gemini OAuth 环境变量配置示例
- config.example.yaml: 新增 Gemini OAuth 配置示例
- README.md: 更新部署文档
- docker-compose.yml: 添加 Gemini OAuth 环境变量传递
2025-12-26 00:11:03 -08:00
ianshaw
2d1a15b196 feat(frontend): 添加 Gemini OAuth 类型国际化
- zh.ts: 添加中文翻译(Code Assist/AI Studio 选择等)
- en.ts: 添加英文翻译
2025-12-26 00:11:03 -08:00
ianshaw
09431cfc0b feat(frontend): 支持 Gemini OAuth 类型选择 (Code Assist/AI Studio)
- CreateAccountModal.vue: 新增 OAuth 类型选择 UI
- ReAuthAccountModal.vue: 重授权支持选择类型
- OAuthAuthorizationFlow.vue: 新增 Project ID 输入框
- AccountTestModal.vue: Gemini 模型默认选择优化
- useGeminiOAuth.ts: OAuth 逻辑参数变更
- gemini.ts: API 调用更新
2025-12-26 00:11:03 -08:00
ianshaw
46cb82bac0 feat(backend): 添加 Gemini V1beta Handler 和路由
- 新增 gemini_v1beta_handler.go: 代理原生 Google API 格式
- 更新 gemini_oauth_handler.go: 移除 redirectUri,新增 oauthType
- 更新 account_handler.go: 账户 Handler 增强
- 更新 router.go: 注册 v1beta 路由
- 更新 config.go: Gemini OAuth 通过环境变量配置
- 更新 wire_gen.go: 依赖注入
2025-12-26 00:11:03 -08:00
ianshaw
b2d71da2a2 feat(backend): 实现 Gemini AI Studio OAuth 和消息兼容服务
- gemini_oauth_service.go: 新增 AI Studio OAuth 类型支持
- gemini_token_provider.go: Token 提供器增强
- gemini_messages_compat_service.go: 支持 AI Studio 端点
- account_test_service.go: Gemini 账户可用性检测
- gateway_service.go: 网关服务适配
- openai_gateway_service.go: OpenAI 兼容层调整
2025-12-26 00:11:03 -08:00
ianshaw
2d6e1d26c0 feat(backend): 扩展 Gemini OAuth Repository 层
- 更新 gemini_oauth_client.go: 支持 AI Studio OAuth 客户端
- 更新 geminicli_codeassist_client.go: 适配新的认证流程
2025-12-26 00:11:03 -08:00
ianshaw
50734c5edc feat(backend): 添加 Google API Key 认证中间件
- 新增 api_key_auth_google.go: 支持 x-goog-api-key 格式认证
- 更新 api_key_auth.go: 适配 Gemini 原生 API 格式
2025-12-26 00:11:03 -08:00
ianshaw
040dc27ea5 feat(backend): 添加 Gemini/Google API 基础包
- 新增 pkg/gemini: 模型定义与回退列表
- 新增 pkg/googleapi: Google API 错误状态处理
- 新增 pkg/geminicli/models.go: CLI 模型结构
- 更新 constants.go: AI Studio 相关常量
- 更新 oauth.go: 支持 AI Studio OAuth 流程,凭据通过环境变量配置
2025-12-26 00:10:44 -08:00
ianshaw
d7090de0e0 chore: 更新 .gitignore 忽略 Go 编译缓存
添加 backend/.gocache/ 到忽略列表
2025-12-26 00:10:44 -08:00
ianshaw
cab681c7d1 chore(frontend): 更新项目配置和依赖
- 更新 package-lock.json 依赖版本
- 优化 Vite、PostCSS、Tailwind 配置
- 更新入口 HTML 文件
- 更新 TypeScript 构建缓存
2025-12-26 00:10:44 -08:00
ianshaw
01f990a5c9 style(frontend): 统一核心模块代码风格
- Composables: 优化 OAuth 相关 hooks 代码格式
- Stores: 规范状态管理模块格式
- Types: 统一类型定义格式
- Utils: 优化工具函数格式
- App.vue & style.css: 调整全局样式和主组件格式
2025-12-26 00:10:44 -08:00
ianshaw
5763f5ced3 style(frontend): 统一 Views 模块代码风格
- 移除语句末尾分号,规范代码格式
- 优化组件结构和类型定义
- 改进视图文档和示例
- 提升代码一致性
2025-12-26 00:10:44 -08:00
ianshaw
f79b0f0fad style(frontend): 统一 API 模块代码风格
- 移除所有语句末尾分号
- 统一对象属性尾随逗号格式
- 优化类型定义导入顺序
- 提升代码一致性和可读性
2025-12-26 00:10:44 -08:00
ianshaw
34183b527b feat(frontend): 添加 OAuth 回调路由
- 新增 /auth/callback 路由用于接收 OAuth 授权回调
- 优化路由配置和文档
- 统一代码格式
2025-12-26 00:10:44 -08:00
ianshaw
bceed08fc3 feat(frontend): 添加 Gemini 平台国际化支持
- 新增中文 Gemini OAuth 相关翻译(步骤说明、错误提示等)
- 新增英文 Gemini OAuth 相关翻译
- 添加 Gemini 账号类型、平台名称等基础翻译
- 优化代码格式
2025-12-26 00:10:44 -08:00
ianshaw
5deef27e1d style(frontend): 优化 Components 代码风格和结构
- 统一移除语句末尾分号,规范代码格式
- 优化组件类型定义和 props 声明
- 改进组件文档和示例代码
- 提升代码可读性和一致性
2025-12-26 00:10:01 -08:00
ianshaw
1ac8b1f03e feat(frontend): Components 集成 Gemini 账号支持
- CreateAccountModal: 添加 Gemini 平台选项和 OAuth 授权流程
- EditAccountModal: 支持 Gemini 账号编辑
- OAuthAuthorizationFlow: 新增 Gemini 平台 OAuth 流程处理(支持 state 参数)
- ReAuthAccountModal: 支持 Gemini 账号重新授权
- 优化代码格式和组件逻辑
2025-12-26 00:09:46 -08:00
ianshaw
0b30cc2b7e feat(frontend): 新增 Gemini OAuth 授权流程
- 新增 /admin/gemini API 接口封装(generateAuthUrl, exchangeCode)
- 新增 useGeminiOAuth composable 处理 Gemini OAuth 流程
- 新增 OAuthCallbackView 视图用于接收 OAuth 回调
- 支持 code/state 参数提取和 credentials 构建
2025-12-26 00:09:46 -08:00
ianshaw
03a8ae62e5 feat(backend): 完善 Gemini OAuth Token 处理
- 修复 account_handler 中 token 字段类型转换(int64 转 string)
- 增强 Account.GetCredential 支持多种数值类型(float64, int, json.Number 等)
- 添加 Account.IsGemini() 方法用于平台判断
- 优化 refresh_token 和 scope 的空值处理
2025-12-26 00:09:46 -08:00
ianshaw
e36fb98fb9 feat(handler): 添加 Gemini OAuth Handler 和完善依赖注入
- 新增 Gemini OAuth 授权处理器
- 扩展账号和网关处理器支持 Gemini
- 注册 Gemini 相关路由
- 更新 Wire 依赖注入配置(所有层)
- 更新 Docker Compose 配置
2025-12-26 00:09:46 -08:00
ianshaw
55258bf099 feat(service): 扩展 CRS 同步和定价服务支持 Gemini
- CRS 同步服务新增 Gemini 账号同步逻辑(+273行)
- 定价服务扩展 Gemini 模型定价计算(+99行)
- 更新 Token 刷新服务集成 Gemini
- 更新相关单元测试
2025-12-26 00:09:04 -08:00
ianshaw
dc109827b7 feat(service): 实现 Gemini OAuth 和 Token 管理服务
- 实现 OAuth 授权流程服务
- 添加 Token 提供者和自动刷新机制
- 实现 Gemini Messages API 兼容层
- 更新服务容器注册
2025-12-26 00:09:04 -08:00
ianshaw
71c28e436a feat(service): 定义 Gemini 服务端口接口
- 定义 OAuth 服务接口
- 定义 Token 缓存服务接口
- 定义 Code Assist 服务接口
2025-12-26 00:08:27 -08:00
ianshaw
2bafc28a9b feat(repository): 实现 Gemini OAuth 和 Token 缓存客户端
- 添加 Gemini OAuth 客户端实现
- 实现 Redis 基础的 Token 缓存
- 添加 gemini-cli Code Assist 客户端封装
2025-12-26 00:08:27 -08:00
ianshaw
aea48ae1ab feat(config): 新增 Gemini 配置项和 geminicli 核心包
- 添加 Gemini OAuth 配置结构
- 实现 geminicli 包(OAuth、Token、CodeAssist 类型)
- 更新配置示例文件
2025-12-26 00:08:27 -08:00
shaw
b3463769dc chore: 调整403重试次数跟间隔 2025-12-26 14:19:57 +08:00
shaw
d9e6cfc44d feat: apikey使用弹出适配codex分组 2025-12-26 13:46:40 +08:00
Forest
57fd172287 refactor: 调整 server 目录结构 2025-12-26 10:42:35 +08:00
NepetaLemon
8d7a497553 refactor: 自定义业务错误 (#33)
* refactor: 自定义业务错误

* refactor: 隐藏服务器错误与统一 panic 响应
2025-12-26 08:47:00 +08:00
shaw
b31698b9f2 fix: 修复账户代理ip编辑保存不生效的bug 2025-12-25 21:58:09 +08:00
Forest
eeaff85e47 refactor: 自定义业务错误 2025-12-25 21:06:40 +08:00
Forest
f51ad2e126 refactor: 删除 ports 目录 2025-12-25 17:15:01 +08:00
hi_yueban
f57f12c6cc fix: 修复 OpenAI 账号 5h/7d 使用限制显示错误的问题 (#30)
* fix: 修复 OpenAI 账号 5h/7d 使用限制显示错误的问题

问题描述:
- 账号管理页面中,OpenAI OAuth 账号的 5h 列显示 7 天的剩余时间
- 7d 列却显示几小时的剩余时间
- 根本原因: OpenAI 响应头中 primary/secondary 的实际含义与代码假设相反

修复方案:
1. 后端归一化 (openai_gateway_service.go):
   - 根据 window_minutes 动态判断哪个是 5h/7d 限制
   - 新增规范字段 codex_5h_* 和 codex_7d_*
   - 保留旧字段以兼容性

2. 前端适配 (AccountUsageCell.vue):
   - 优先使用新的规范字段
   - Fallback 到旧字段时基于 window_minutes 动态判断
   - 更新 computed 属性命名

3. 类型定义更新 (types/index.ts):
   - 添加新的规范字段定义
   - 更新注释说明实际语义由 window_minutes 决定

🤖 Generated with Claude Code and Codex collaboration

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-Authored-By: OpenAI Codex <noreply@openai.com>

* fix: 改进窗口判断逻辑,修复两个窗口都小于阈值时的bug

问题:
当两个窗口都小于360分钟时(如 primary=180分钟,secondary=300分钟),
之前的逻辑会导致:
- primary5h = true, secondary5h = true
- 5h 字段会使用 primary(错误)
- 7d 字段没有数据(bug)

修复方案:
改用比较策略:
1. 当两个窗口都存在时:较小的分配给5h,较大的分配给7d
2. 当只有一个窗口时:根据大小(<=360分钟)判断是5h还是7d
3. 确保数据不会丢失,逻辑更健壮

示例:
- Primary: 180分钟, Secondary: 300分钟
  → 5h 使用 Primary(180分钟), 7d 使用 Secondary(300分钟) ✓

🤖 Generated with Claude Code

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix: 修正窗口大小判断逻辑 - 不能用剩余时间判断窗口类型

**严重bug修复:**
之前的 fallback 逻辑错误地使用 reset_after_seconds 来判断窗口大小。

问题示例:
- 周限制(7d)剩余 2h → reset_after_seconds = 7200秒
- 5h限制 剩余 4h → reset_after_seconds = 14400秒
- 错误逻辑:7200/60 < 14400/60,把周限制当成5h 

根本问题:
- window_minutes = 窗口的总大小(300 or 10080)
- reset_after_seconds = 距离重置的剩余时间(变化的)
- 不能用剩余时间来判断窗口类型!

修复方案:
1. **只使用 window_minutes** 来判断窗口大小
2. 移除错误的 reset_after_seconds fallback
3. 如果 window_minutes 都不存在,使用传统假设
4. 添加详细注释说明这个陷阱

🤖 Generated with Claude Code

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix: 修复 lint 问题 - 改进 fallback 逻辑的变量赋值

问题:
第882-883行的简单布尔赋值可能触发 ineffassign 或 staticcheck 警告:
  use5hFromSecondary = snapshot.SecondaryUsedPercent != nil
  use7dFromPrimary = snapshot.PrimaryUsedPercent != nil

修复:
改用明确的 if 语句检查任意字段是否存在,更符合代码意图:
- 如果 secondary 的任意字段存在,将其视为 5h
- 如果 primary 的任意字段存在,将其视为 7d

这样逻辑更清晰,也避免了 lint 警告。

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: OpenAI Codex <noreply@openai.com>
2025-12-25 17:00:02 +08:00
shaw
5fca2d10b9 fix: 修复image地址 2025-12-25 16:57:29 +08:00
shaw
8fbe1ad70d chore: 调整403重试次数跟间隔 2025-12-25 16:23:31 +08:00
Forest
25a304c231 test: 增加 repository 测试 2025-12-25 16:01:17 +08:00
刀刀
9d30ceae8d CC 400 返回具体错误信息 && 非 CC 请求时增加 system prompt (#26)
* feat: http 400 返回具体错误

* 更新 workflows

* 优化打包/docker 构建流程

* 400 是返回 原始错误 - json 格式

* feat: 非 cc请求时补充 system

* go mod tidy
2025-12-25 14:47:19 +08:00
IanShaw
60f6ed6bf6 feat: CRS 同步增强 - 自动刷新 OAuth token 和修复测试配置 (#27)
* fix(service): 修复 OpenAI Responses API 测试负载配置

- 所有账号类型统一添加 instructions 字段(不再仅限 OAuth)
- Responses API 要求所有请求必须包含 instructions 参数

* feat(crs-sync): CRS 同步时自动刷新 OAuth token 并保留完整 extra 字段

**核心功能**:
- CRSSyncService 注入 OAuth 服务依赖(Anthropic + OpenAI)
- 账号创建/更新后自动刷新 OAuth token,确保可用性
- 完整保留 CRS extra 字段,避免数据丢失

**Extra 字段增强**:
- 保留 CRS 所有原始 extra 字段
- 新增同步元数据: crs_account_id, crs_kind, crs_synced_at
- Claude 账号: 从 credentials 提取 org_uuid/account_uuid 到 extra
- OpenAI 账号: 映射 crs_email -> email

**Token 刷新逻辑**:
- 新增 refreshOAuthToken() 方法处理 Anthropic/OpenAI 平台
- 保留原有 credentials 字段,仅更新 token 相关字段
- 刷新失败静默处理,不中断同步流程

**依赖注入**:
- wire_gen.go: CRSSyncService 新增 oAuthService/openaiOAuthService

* style(crs-sync): 使用 switch 替代 if-else 修复 golangci-lint 警告

- 将 refreshOAuthToken 中的 if-else 改为 switch 语句
- 符合 staticcheck 规范
- 添加 default 分支处理未知平台
2025-12-25 14:45:17 +08:00
shaw
4a2f7d4a99 chore: CRS迁移功能增加版本提示 2025-12-25 10:57:04 +08:00
shaw
c19a393be9 Merge PR #24: feat: 添加账户同步与批量编辑功能
- 添加从 CRS 同步账户功能 (Claude OAuth/API Key, OpenAI OAuth/Responses)
- 添加批量编辑账户功能,支持 JSONB 字段智能合并
- 新增 CRSSyncService、BulkUpdate 仓储方法
- 前端新增 SyncFromCrsModal 和 BulkEditAccountModal 组件
2025-12-25 10:44:40 +08:00
ianshaw
938ffb002e style(frontend): format code with prettier
格式化前端业务代码,符合代码规范
- 统一代码风格
- 修复 ESLint 警告
2025-12-24 18:07:58 -08:00
ianshaw
372a01290b fix(backend): handle defer Close() errors in crs_sync_service
修复 golangci-lint 错误检查问题
- 使用匿名函数包装 defer Close() 并忽略错误
- 符合 Go 最佳实践
2025-12-24 17:58:47 -08:00
ianshaw
8b163ca49b chore: trigger CI after enabling Actions 2025-12-24 17:56:55 -08:00
ianshaw
d23810dc53 chore: trigger CI workflow 2025-12-24 17:54:43 -08:00
ianshaw
62ed5422dd feat(account): 优化批量更新实现,使用统一 SQL 合并 JSONB 字段
- 新增 BulkUpdate 仓储方法,使用单条 SQL 更新所有账户
- credentials/extra 使用 COALESCE(...) || ? 合并,只更新传入的 key
- name/proxy_id/concurrency/priority/status 只在提供时更新
- 分组绑定仍逐账号处理(需要独立操作)
- 前端优化:Base URL 留空则不修改,按勾选字段更新
- 完善 i18n 文案:说明留空不修改、批量更新行为
2025-12-24 17:16:19 -08:00
ianshaw
2e76302af7 feat(account): 添加批量编辑账户凭据功能并优化 CRS 同步
- 新增批量更新账户凭据接口(account_uuid/org_uuid/intercept_warmup_requests)
- 新增前端批量编辑模态框组件
- 优化 CRS 同步逻辑,改进 extra 字段处理
- 优化 CRS 同步 UI,添加更详细的结果展示
- 完善国际化文案(中英文)
2025-12-24 16:56:48 -08:00
ianshaw
6553828008 feat(account): 添加从 CRS 同步账户功能
- 添加账户同步 API 接口 (account_handler.go)
- 实现 CRS 同步服务 (crs_sync_service.go)
- 添加前端同步对话框组件 (SyncFromCrsModal.vue)
- 更新账户管理界面支持同步操作
- 添加账户仓库批量创建方法
- 添加中英文国际化翻译
- 更新依赖注入配置
2025-12-24 08:48:58 -08:00
ianshaw
adcb7bf00e chore: 更新 .gitignore 忽略配置文件并还原 Makefile
- 添加 backend/config.yaml 到 .gitignore(包含敏感信息)
- 添加 deploy/config.yaml 到 .gitignore(包含敏感信息)
- 添加 backend/.installed 到 .gitignore
- 还原 Makefile 到原始版本
2025-12-24 08:48:49 -08:00
349 changed files with 39405 additions and 11482 deletions

View File

@@ -17,9 +17,12 @@ jobs:
go-version-file: backend/go.mod
check-latest: true
cache: true
- name: Run tests
- name: Unit tests
working-directory: backend
run: go test ./...
run: make test-unit
- name: Integration tests
working-directory: backend
run: make test-integration
golangci-lint:
runs-on: ubuntu-latest

View File

@@ -85,6 +85,19 @@ jobs:
go-version: '1.24'
cache-dependency-path: backend/go.sum
# Docker setup for GoReleaser
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Fetch tags with annotations
run: |
# 确保获取完整的 annotated tag 信息
@@ -117,87 +130,74 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
GITHUB_REPO_OWNER: ${{ github.repository_owner }}
GITHUB_REPO_NAME: ${{ github.event.repository.name }}
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
# ===========================================================================
# Docker Build and Push
# ===========================================================================
docker:
needs: [update-version, build-frontend]
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Download VERSION artifact
uses: actions/download-artifact@v4
with:
name: version-file
path: backend/cmd/server/
- name: Download frontend artifact
uses: actions/download-artifact@v4
with:
name: frontend-dist
path: backend/internal/web/dist/
# Extract version from tag
- name: Extract version
id: version
run: |
VERSION=${GITHUB_REF#refs/tags/v}
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Version: $VERSION"
# Set up Docker Buildx for multi-platform builds
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
# Login to DockerHub
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Extract metadata for Docker
- name: Extract Docker metadata
id: meta
uses: docker/metadata-action@v5
with:
images: |
weishaw/sub2api
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=raw,value=latest,enable={{is_default_branch}}
# Build and push Docker image
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
COMMIT=${{ github.sha }}
DATE=${{ github.event.head_commit.timestamp }}
cache-from: type=gha
cache-to: type=gha,mode=max
# Update DockerHub description (optional)
# Update DockerHub description
- name: Update DockerHub description
uses: peter-evans/dockerhub-description@v4
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
repository: weishaw/sub2api
repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
short-description: "Sub2API - AI API Gateway Platform"
readme-filepath: ./deploy/DOCKER.md
# Send Telegram notification
- name: Send Telegram Notification
env:
TELEGRAM_BOT_TOKEN: ${{ secrets.TELEGRAM_BOT_TOKEN }}
TELEGRAM_CHAT_ID: ${{ secrets.TELEGRAM_CHAT_ID }}
continue-on-error: true
run: |
# 检查必要的环境变量
if [ -z "$TELEGRAM_BOT_TOKEN" ] || [ -z "$TELEGRAM_CHAT_ID" ]; then
echo "Telegram credentials not configured, skipping notification"
exit 0
fi
TAG_NAME=${GITHUB_REF#refs/tags/}
VERSION=${TAG_NAME#v}
REPO="${{ github.repository }}"
DOCKER_IMAGE="${{ secrets.DOCKERHUB_USERNAME }}/sub2api"
# 获取 tag message 内容
TAG_MESSAGE='${{ steps.tag_message.outputs.message }}'
# 限制消息长度Telegram 消息限制 4096 字符,预留空间给头尾固定内容)
if [ ${#TAG_MESSAGE} -gt 3500 ]; then
TAG_MESSAGE="${TAG_MESSAGE:0:3500}..."
fi
# 构建消息内容
MESSAGE="🚀 *Sub2API 新版本发布!*"$'\n'$'\n'
MESSAGE+="📦 版本号: \`${VERSION}\`"$'\n'$'\n'
# 添加更新内容
if [ -n "$TAG_MESSAGE" ]; then
MESSAGE+="${TAG_MESSAGE}"$'\n'$'\n'
fi
MESSAGE+="🐳 *Docker 部署:*"$'\n'
MESSAGE+="\`\`\`bash"$'\n'
MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n'
MESSAGE+="docker pull ${DOCKER_IMAGE}:latest"$'\n'
MESSAGE+="\`\`\`"$'\n'$'\n'
MESSAGE+="🔗 *相关链接:*"$'\n'
MESSAGE+="• [GitHub Release](https://github.com/${REPO}/releases/tag/${TAG_NAME})"$'\n'
MESSAGE+="• [Docker Hub](https://hub.docker.com/r/${DOCKER_IMAGE})"$'\n'$'\n'
MESSAGE+="#Sub2API #Release #${TAG_NAME//./_}"
# 发送消息
curl -s -X POST "https://api.telegram.org/bot${TELEGRAM_BOT_TOKEN}/sendMessage" \
-H "Content-Type: application/json" \
-d "$(jq -n \
--arg chat_id "${TELEGRAM_CHAT_ID}" \
--arg text "${MESSAGE}" \
'{
chat_id: $chat_id,
text: $text,
parse_mode: "Markdown",
disable_web_page_preview: true
}')"

10
.gitignore vendored
View File

@@ -21,6 +21,9 @@ coverage.html
# 依赖(使用 go mod
vendor/
# Go 编译缓存
backend/.gocache/
# ===================
# Node.js / Vue 前端
# ===================
@@ -92,6 +95,13 @@ backend/internal/web/dist/*
# 后端运行时缓存数据
backend/data/
# ===================
# 本地配置文件(包含敏感信息)
# ===================
backend/config.yaml
deploy/config.yaml
backend/.installed
# ===================
# 其他
# ===================

View File

@@ -52,10 +52,58 @@ changelog:
# 禁用自动 changelog完全使用 tag 消息
disable: true
# Docker images
dockers:
- id: amd64
goos: linux
goarch: amd64
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
- id: arm64
goos: linux
goarch: arm64
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
# Docker manifests for multi-arch support
docker_manifests:
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:latest"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}.{{ .Minor }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
release:
github:
owner: Wei-Shaw
name: sub2api
owner: "{{ .Env.GITHUB_REPO_OWNER }}"
name: "{{ .Env.GITHUB_REPO_NAME }}"
draft: false
prerelease: auto
name_template: "Sub2API {{.Version}}"
@@ -73,7 +121,7 @@ release:
**One-line install (Linux):**
```bash
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
curl -sSL https://raw.githubusercontent.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/main/deploy/install.sh | sudo bash
```
**Manual download:**
@@ -81,5 +129,5 @@ release:
## 📚 Documentation
- [GitHub Repository](https://github.com/Wei-Shaw/sub2api)
- [Installation Guide](https://github.com/Wei-Shaw/sub2api/blob/main/deploy/README.md)
- [GitHub Repository](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }})
- [Installation Guide](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/blob/main/deploy/README.md)

View File

@@ -9,7 +9,7 @@
# -----------------------------------------------------------------------------
# Stage 1: Frontend Builder
# -----------------------------------------------------------------------------
FROM node:20-alpine AS frontend-builder
FROM node:24-alpine AS frontend-builder
WORKDIR /app/frontend
@@ -24,7 +24,7 @@ RUN npm run build
# -----------------------------------------------------------------------------
# Stage 2: Backend Builder
# -----------------------------------------------------------------------------
FROM golang:1.24-alpine AS backend-builder
FROM golang:1.25-alpine AS backend-builder
# Build arguments for version info (set by CI)
ARG VERSION=docker

40
Dockerfile.goreleaser Normal file
View File

@@ -0,0 +1,40 @@
# =============================================================================
# Sub2API Dockerfile for GoReleaser
# =============================================================================
# This Dockerfile is used by GoReleaser to build Docker images.
# It only packages the pre-built binary, no compilation needed.
# =============================================================================
FROM alpine:3.19
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
LABEL description="Sub2API - AI API Gateway Platform"
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
# Install runtime dependencies
RUN apk add --no-cache \
ca-certificates \
tzdata \
curl \
&& rm -rf /var/cache/apk/*
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
WORKDIR /app
# Copy pre-built binary from GoReleaser
COPY sub2api /app/sub2api
# Create data directory
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
USER sub2api
EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
ENTRYPOINT ["/app/sub2api"]

View File

@@ -19,15 +19,23 @@ linters:
files:
- "**/internal/service/**"
deny:
- pkg: sub2api/internal/repository
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
desc: "service must not import repository"
- pkg: gorm.io/gorm
desc: "service must not import gorm"
- pkg: github.com/redis/go-redis/v9
desc: "service must not import redis"
handler-no-repository:
list-mode: original
files:
- "**/internal/handler/**"
deny:
- pkg: sub2api/internal/repository
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
desc: "handler must not import repository"
- pkg: gorm.io/gorm
desc: "handler must not import gorm"
- pkg: github.com/redis/go-redis/v9
desc: "handler must not import redis"
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.

View File

@@ -1,4 +1,4 @@
.PHONY: wire build build-embed
.PHONY: wire build build-embed test-unit test-integration test-cover-integration clean-coverage
wire:
@echo "生成 Wire 代码..."
@@ -13,4 +13,21 @@ build:
build-embed:
@echo "构建后端(嵌入前端)..."
@go build -tags embed -o bin/server ./cmd/server
@echo "构建完成: bin/server (with embedded frontend)"
@echo "构建完成: bin/server (with embedded frontend)"
test-unit:
@go test -tags unit ./... -count=1
test-integration:
@go test -tags integration ./... -count=1 -race -parallel=8
test-cover-integration:
@echo "运行集成测试并生成覆盖率报告..."
@go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./...
@go tool cover -func=coverage.out | tail -1
@go tool cover -html=coverage.out -o coverage.html
@echo "覆盖率报告已生成: coverage.html"
clean-coverage:
@rm -f coverage.out coverage.html
@echo "覆盖率文件已清理"

View File

@@ -17,7 +17,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/setup"
"github.com/Wei-Shaw/sub2api/internal/web"
@@ -84,7 +84,7 @@ func main() {
func runSetupServer() {
r := gin.New()
r.Use(gin.Recovery())
r.Use(middleware.Recovery())
r.Use(middleware.CORS())
// Register setup routes

View File

@@ -4,18 +4,19 @@
package main
import (
"context"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"context"
"log"
"net/http"
"time"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
@@ -35,6 +36,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// 业务层 ProviderSets
repository.ProviderSet,
service.ProviderSet,
middleware.ProviderSet,
handler.ProviderSet,
// 服务器层 ProviderSet
@@ -62,7 +64,12 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
func provideCleanup(
db *gorm.DB,
rdb *redis.Client,
services *service.Services,
tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -74,23 +81,27 @@ func provideCleanup(
fn func() error
}{
{"TokenRefreshService", func() error {
services.TokenRefresh.Stop()
tokenRefresh.Stop()
return nil
}},
{"PricingService", func() error {
services.Pricing.Stop()
pricing.Stop()
return nil
}},
{"EmailQueueService", func() error {
services.EmailQueue.Stop()
emailQueue.Stop()
return nil
}},
{"OAuthService", func() error {
services.OAuth.Stop()
oauth.Stop()
return nil
}},
{"OpenAIOAuthService", func() error {
services.OpenAIOAuth.Stop()
openaiOAuth.Stop()
return nil
}},
{"GeminiOAuthService", func() error {
geminiOAuth.Stop()
return nil
}},
{"Redis", func() error {

View File

@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
@@ -47,8 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
authHandler := handler.NewAuthHandler(authService)
userService := service.NewUserService(userRepository)
authHandler := handler.NewAuthHandler(authService, userService)
userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db)
@@ -79,16 +80,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
geminiTokenCache := repository.NewGeminiTokenCache(client)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, httpUpstream)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService)
@@ -99,7 +107,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
gatewayCache := repository.NewGatewayCache(client)
pricingRemoteClient := repository.NewPricingRemoteClient()
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
@@ -110,59 +118,19 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityCache := repository.NewIdentityCache(client)
identityService := service.NewIdentityService(identityCache)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
groupService := service.NewGroupService(groupRepository)
accountService := service.NewAccountService(accountRepository, groupRepository)
proxyService := service.NewProxyService(proxyRepository)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig)
services := &service.Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OpenAIGateway: openAIGatewayService,
OAuth: oAuthService,
OpenAIOAuth: openAIOAuthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
Update: updateService,
TokenRefresh: tokenRefreshService,
}
repositories := &repository.Repositories{
User: userRepository,
ApiKey: apiKeyRepository,
Group: groupRepository,
Account: accountRepository,
Proxy: proxyRepository,
RedeemCode: redeemCodeRepository,
UsageLog: usageLogRepository,
Setting: settingRepository,
UserSubscription: userSubscriptionRepository,
}
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
httpServer := server.ProvideHTTPServer(configConfig, engine)
v := provideCleanup(db, client, services)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -187,7 +155,12 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
func provideCleanup(
db *gorm.DB,
rdb *redis.Client,
services *service.Services,
tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -198,23 +171,27 @@ func provideCleanup(
fn func() error
}{
{"TokenRefreshService", func() error {
services.TokenRefresh.Stop()
tokenRefresh.Stop()
return nil
}},
{"PricingService", func() error {
services.Pricing.Stop()
pricing.Stop()
return nil
}},
{"EmailQueueService", func() error {
services.EmailQueue.Stop()
emailQueue.Stop()
return nil
}},
{"OAuthService", func() error {
services.OAuth.Stop()
oauth.Stop()
return nil
}},
{"OpenAIOAuthService", func() error {
services.OpenAIOAuth.Stop()
openaiOAuth.Stop()
return nil
}},
{"GeminiOAuthService", func() error {
geminiOAuth.Stop()
return nil
}},
{"Redis", func() error {

View File

@@ -1,38 +0,0 @@
server:
host: "0.0.0.0"
port: 8080
mode: "debug" # debug/release
database:
host: "127.0.0.1"
port: 5432
user: "postgres"
password: "XZeRr7nkjHWhm8fw"
dbname: "sub2api"
sslmode: "disable"
redis:
host: "127.0.0.1"
port: 6379
password: ""
db: 0
jwt:
secret: "your-secret-key-change-in-production"
expire_hour: 24
default:
admin_email: "admin@sub2api.com"
admin_password: "admin123"
user_concurrency: 5
user_balance: 0
api_key_prefix: "sk-"
rate_multiplier: 1.0
# Timezone configuration (similar to PHP's date_default_timezone_set)
# This affects ALL time operations:
# - Database timestamps
# - Usage statistics "today" boundary
# - Subscription expiry times
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
timezone: "Asia/Shanghai"

View File

@@ -8,64 +8,121 @@ require (
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.0
github.com/google/uuid v1.6.0
github.com/google/wire v0.7.0
github.com/imroc/req/v3 v3.56.0
github.com/lib/pq v1.10.9
github.com/redis/go-redis/v9 v9.3.0
github.com/redis/go-redis/v9 v9.7.3
github.com/spf13/viper v1.18.2
github.com/stretchr/testify v1.11.1
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
golang.org/x/crypto v0.44.0
golang.org/x/net v0.47.0
golang.org/x/term v0.37.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/datatypes v1.2.0
gorm.io/driver/postgres v1.5.4
gorm.io/gorm v1.25.5
)
require (
dario.cat/mergo v1.0.2 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/google/wire v0.7.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.1 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mdelapenya/tlscert v0.2.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/go-archive v0.1.0 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.6.0 // indirect
github.com/moby/sys/user v0.4.0 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.56.0 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/metric v1.37.0 // indirect
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
@@ -75,6 +132,8 @@ require (
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/tools v0.38.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gorm.io/driver/mysql v1.5.2 // indirect
)

View File

@@ -1,3 +1,13 @@
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
@@ -7,17 +17,43 @@ github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
@@ -28,6 +64,13 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -36,13 +79,19 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
@@ -54,6 +103,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
@@ -62,10 +113,12 @@ github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo=
github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
@@ -85,36 +138,74 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE=
github.com/microsoft/go-mssqldb v0.17.0/go.mod h1:OkoNGhGEs8EZqchVTtochlXruEhEOaO4S0d2sB5aeGQ=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY=
github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c=
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
@@ -128,6 +219,8 @@ github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMV
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -135,16 +228,55 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0/go.mod h1:h+u/2KoREGTnTl9UwrQ/g+XhasAT8E6dClclAADeXoQ=
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0 h1:OG4qwcxp2O0re7V7M9lY9w0v6wWgWf7j7rtkpAnGMd0=
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0/go.mod h1:Bc+EDhKMo5zI5V5zdBkHiMVzeAXbtI4n5isS/nzf6zw=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
@@ -164,8 +296,14 @@ golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
@@ -177,9 +315,15 @@ golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 h1:i8QOKZfYg6AbGVZzUAY3LrNWCKF8O6zFisU9Wl9RER4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4/go.mod h1:HSkG/KdJWusxU1F6CNrwNDjBMgisKxGnc5dAZfT0mjQ=
google.golang.org/grpc v1.75.1 h1:/ODCNEuf9VghjgO3rqLcfg8fiOP0nSluljWFlDxELLI=
google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
@@ -188,10 +332,19 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco=
gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04=
gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs=
gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8=
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0=
gorm.io/driver/sqlserver v1.4.1/go.mod h1:DJ4P+MeZbc5rvY58PnmN1Lnyvb5gw5NPzGshHDnJLig=
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -18,6 +18,17 @@ type Config struct {
Gateway GatewayConfig `mapstructure:"gateway"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
}
type GeminiConfig struct {
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
}
type GeminiOAuthConfig struct {
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
Scopes string `mapstructure:"scopes"`
}
// TokenRefreshConfig OAuth token自动刷新配置
@@ -211,9 +222,16 @@ func setDefaults() {
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新适配Google 1小时token
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
// Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
// Default: uses Gemini CLI public credentials (set via environment)
viper.SetDefault("gemini.oauth.client_id", "")
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
}
func (c *Config) Validate() error {

View File

@@ -2,9 +2,11 @@ package admin
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -30,22 +32,36 @@ type AccountHandler struct {
adminService service.AdminService
oauthService *service.OAuthService
openaiOAuthService *service.OpenAIOAuthService
geminiOAuthService *service.GeminiOAuthService
rateLimitService *service.RateLimitService
accountUsageService *service.AccountUsageService
accountTestService *service.AccountTestService
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
}
// NewAccountHandler creates a new admin account handler
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, openaiOAuthService *service.OpenAIOAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService, concurrencyService *service.ConcurrencyService) *AccountHandler {
func NewAccountHandler(
adminService service.AdminService,
oauthService *service.OAuthService,
openaiOAuthService *service.OpenAIOAuthService,
geminiOAuthService *service.GeminiOAuthService,
rateLimitService *service.RateLimitService,
accountUsageService *service.AccountUsageService,
accountTestService *service.AccountTestService,
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
) *AccountHandler {
return &AccountHandler{
adminService: adminService,
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService,
rateLimitService: rateLimitService,
accountUsageService: accountUsageService,
accountTestService: accountTestService,
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
}
}
@@ -76,9 +92,22 @@ type UpdateAccountRequest struct {
GroupIDs *[]int64 `json:"group_ids"`
}
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
type BulkUpdateAccountsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct {
*model.Account
*dto.Account
CurrentConcurrency int `json:"current_concurrency"`
}
@@ -93,7 +122,7 @@ func (h *AccountHandler) List(c *gin.Context) {
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
if err != nil {
response.InternalError(c, "Failed to list accounts: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -113,7 +142,7 @@ func (h *AccountHandler) List(c *gin.Context) {
result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts {
result[i] = AccountWithConcurrency{
Account: &accounts[i],
Account: dto.AccountFromService(&accounts[i]),
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
}
}
@@ -132,11 +161,11 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
response.ErrorFrom(c, err)
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// Create handles creating a new account
@@ -160,11 +189,11 @@ func (h *AccountHandler) Create(c *gin.Context) {
GroupIDs: req.GroupIDs,
})
if err != nil {
response.BadRequest(c, "Failed to create account: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// Update handles updating an account
@@ -194,11 +223,11 @@ func (h *AccountHandler) Update(c *gin.Context) {
GroupIDs: req.GroupIDs,
})
if err != nil {
response.InternalError(c, "Failed to update account: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// Delete handles deleting an account
@@ -212,7 +241,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteAccount(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to delete account: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -224,6 +253,13 @@ type TestAccountRequest struct {
ModelID string `json:"model_id"`
}
type SyncFromCRSRequest struct {
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
SyncProxies *bool `json:"sync_proxies"`
}
// Test handles testing account connectivity with SSE streaming
// POST /api/v1/admin/accounts/:id/test
func (h *AccountHandler) Test(c *gin.Context) {
@@ -244,6 +280,35 @@ func (h *AccountHandler) Test(c *gin.Context) {
}
}
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
// POST /api/v1/admin/accounts/sync/crs
func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
var req SyncFromCRSRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Default to syncing proxies (can be disabled by explicitly setting false)
syncProxies := true
if req.SyncProxies != nil {
syncProxies = *req.SyncProxies
}
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
SyncProxies: syncProxies,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) {
@@ -272,7 +337,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
// Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -285,11 +350,24 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
newCredentials[k] = v
}
}
} else if account.Platform == service.PlatformGemini {
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
} else {
// Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -302,21 +380,25 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
// Update token-related fields
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = tokenInfo.ExpiresIn
newCredentials["expires_at"] = tokenInfo.ExpiresAt
newCredentials["refresh_token"] = tokenInfo.RefreshToken
newCredentials["scope"] = tokenInfo.Scope
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if strings.TrimSpace(tokenInfo.RefreshToken) != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if strings.TrimSpace(tokenInfo.Scope) != "" {
newCredentials["scope"] = tokenInfo.Scope
}
}
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, updatedAccount)
response.Success(c, dto.AccountFromService(updatedAccount))
}
// GetStats handles getting account statistics
@@ -343,7 +425,7 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get account stats: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -361,11 +443,11 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to clear error: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// BatchCreate handles batch creating accounts
@@ -387,6 +469,136 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
})
}
// BatchUpdateCredentialsRequest represents batch credentials update request
type BatchUpdateCredentialsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Field string `json:"field" binding:"required,oneof=account_uuid org_uuid intercept_warmup_requests"`
Value any `json:"value"`
}
// BatchUpdateCredentials handles batch updating credentials fields
// POST /api/v1/admin/accounts/batch-update-credentials
func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
var req BatchUpdateCredentialsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Validate value type based on field
if req.Field == "intercept_warmup_requests" {
// Must be boolean
if _, ok := req.Value.(bool); !ok {
response.BadRequest(c, "intercept_warmup_requests must be boolean")
return
}
} else {
// account_uuid and org_uuid can be string or null
if req.Value != nil {
if _, ok := req.Value.(string); !ok {
response.BadRequest(c, req.Field+" must be string or null")
return
}
}
}
ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
}
// Update credentials field
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials[req.Field] = req.Value
// Update account
updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials,
}
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"account_id": accountID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
})
}
// BulkUpdate handles bulk updating accounts with selected fields/credentials.
// POST /api/v1/admin/accounts/bulk-update
func (h *AccountHandler) BulkUpdate(c *gin.Context) {
var req BulkUpdateAccountsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
hasUpdates := req.Name != "" ||
req.ProxyID != nil ||
req.Concurrency != nil ||
req.Priority != nil ||
req.Status != "" ||
req.GroupIDs != nil ||
len(req.Credentials) > 0 ||
len(req.Extra) > 0
if !hasUpdates {
response.BadRequest(c, "No updates provided")
return
}
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
AccountIDs: req.AccountIDs,
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: req.Status,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ========== OAuth Handlers ==========
// GenerateAuthURLRequest represents the request for generating auth URL
@@ -405,7 +617,7 @@ func (h *OAuthHandler) GenerateAuthURL(c *gin.Context) {
result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -423,7 +635,7 @@ func (h *OAuthHandler) GenerateSetupTokenURL(c *gin.Context) {
result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "Failed to generate setup token URL: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -452,7 +664,7 @@ func (h *OAuthHandler) ExchangeCode(c *gin.Context) {
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -474,7 +686,7 @@ func (h *OAuthHandler) ExchangeSetupTokenCode(c *gin.Context) {
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -502,7 +714,7 @@ func (h *OAuthHandler) CookieAuth(c *gin.Context) {
Scope: "full",
})
if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -524,7 +736,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
Scope: "inference",
})
if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -542,7 +754,7 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to get usage: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -560,7 +772,7 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to clear rate limit: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -578,7 +790,7 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to get today stats: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -607,11 +819,11 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable)
if err != nil {
response.InternalError(c, "Failed to update schedulable status: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// GetAvailableModels handles getting available models for an account
@@ -668,6 +880,44 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return
}
// Handle Gemini accounts
if account.IsGemini() {
// For OAuth accounts: return default Gemini models
if account.IsOAuth() {
response.Success(c, geminicli.DefaultModels)
return
}
// For API Key accounts: return models based on model_mapping
mapping := account.GetModelMapping()
if len(mapping) == 0 {
response.Success(c, geminicli.DefaultModels)
return
}
var models []geminicli.Model
for requestedModel := range mapping {
var found bool
for _, dm := range geminicli.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
if !found {
models = append(models, geminicli.Model{
ID: requestedModel,
Type: "model",
DisplayName: requestedModel,
CreatedAt: "",
})
}
}
response.Success(c, models)
return
}
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {

View File

@@ -1,11 +1,12 @@
package admin
import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"strconv"
"time"
"github.com/gin-gonic/gin"
)

View File

@@ -0,0 +1,135 @@
package admin
import (
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type GeminiOAuthHandler struct {
geminiOAuthService *service.GeminiOAuthService
}
func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *GeminiOAuthHandler {
return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService}
}
// GET /api/v1/admin/gemini/oauth/capabilities
func (h *GeminiOAuthHandler) GetCapabilities(c *gin.Context) {
cfg := h.geminiOAuthService.GetOAuthConfig()
response.Success(c, cfg)
}
type GeminiGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
ProjectID string `json:"project_id"`
// OAuth 类型: "code_assist" (需要 project_id) 或 "ai_studio" (不需要 project_id)
// 默认为 "code_assist" 以保持向后兼容
OAuthType string `json:"oauth_type"`
}
// GenerateAuthURL generates Google OAuth authorization URL for Gemini.
// POST /api/v1/admin/gemini/oauth/auth-url
func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
var req GeminiGenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 默认使用 code_assist 以保持向后兼容
oauthType := strings.TrimSpace(req.OAuthType)
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
return
}
// Always pass the "hosted" callback URI; the OAuth service may override it depending on
// oauth_type and whether the built-in Gemini CLI OAuth client is used.
redirectURI := deriveGeminiRedirectURI(c)
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType)
if err != nil {
msg := err.Error()
// Treat missing/invalid OAuth client configuration as a user/config error.
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
response.BadRequest(c, "Failed to generate auth URL: "+msg)
return
}
response.InternalError(c, "Failed to generate auth URL: "+msg)
return
}
response.Success(c, result)
}
type GeminiExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
State string `json:"state" binding:"required"`
Code string `json:"code" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
// OAuth 类型: "code_assist" 或 "ai_studio",需要与 GenerateAuthURL 时的类型一致
OAuthType string `json:"oauth_type"`
}
// ExchangeCode exchanges authorization code for tokens.
// POST /api/v1/admin/gemini/oauth/exchange-code
func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
var req GeminiExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 默认使用 code_assist 以保持向后兼容
oauthType := strings.TrimSpace(req.OAuthType)
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
return
}
tokenInfo, err := h.geminiOAuthService.ExchangeCode(c.Request.Context(), &service.GeminiExchangeCodeInput{
SessionID: req.SessionID,
State: req.State,
Code: req.Code,
ProxyID: req.ProxyID,
OAuthType: oauthType,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
func deriveGeminiRedirectURI(c *gin.Context) string {
origin := strings.TrimSpace(c.GetHeader("Origin"))
if origin != "" {
return strings.TrimRight(origin, "/") + "/auth/callback"
}
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
if xfProto := strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")); xfProto != "" {
scheme = strings.TrimSpace(strings.Split(xfProto, ",")[0])
}
host := strings.TrimSpace(c.Request.Host)
if xfHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); xfHost != "" {
host = strings.TrimSpace(strings.Split(xfHost, ",")[0])
}
return fmt.Sprintf("%s://%s/auth/callback", scheme, host)
}

View File

@@ -3,7 +3,7 @@ package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -65,11 +65,15 @@ func (h *GroupHandler) List(c *gin.Context) {
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
if err != nil {
response.InternalError(c, "Failed to list groups: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, groups, total, page, pageSize)
outGroups := make([]dto.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
}
response.Paginated(c, outGroups, total, page, pageSize)
}
// GetAll handles getting all active groups without pagination
@@ -77,7 +81,7 @@ func (h *GroupHandler) List(c *gin.Context) {
func (h *GroupHandler) GetAll(c *gin.Context) {
platform := c.Query("platform")
var groups []model.Group
var groups []service.Group
var err error
if platform != "" {
@@ -87,11 +91,15 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
}
if err != nil {
response.InternalError(c, "Failed to get groups: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, groups)
outGroups := make([]dto.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
}
response.Success(c, outGroups)
}
// GetByID handles getting a group by ID
@@ -105,11 +113,11 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
if err != nil {
response.NotFound(c, "Group not found")
response.ErrorFrom(c, err)
return
}
response.Success(c, group)
response.Success(c, dto.GroupFromService(group))
}
// Create handles creating a new group
@@ -133,11 +141,11 @@ func (h *GroupHandler) Create(c *gin.Context) {
MonthlyLimitUSD: req.MonthlyLimitUSD,
})
if err != nil {
response.BadRequest(c, "Failed to create group: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, group)
response.Success(c, dto.GroupFromService(group))
}
// Update handles updating a group
@@ -168,11 +176,11 @@ func (h *GroupHandler) Update(c *gin.Context) {
MonthlyLimitUSD: req.MonthlyLimitUSD,
})
if err != nil {
response.InternalError(c, "Failed to update group: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, group)
response.Success(c, dto.GroupFromService(group))
}
// Delete handles deleting a group
@@ -186,7 +194,7 @@ func (h *GroupHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
if err != nil {
response.InternalError(c, "Failed to delete group: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -225,9 +233,13 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get group API keys: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, keys, total, page, pageSize)
outKeys := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, outKeys, total, page, pageSize)
}

View File

@@ -3,6 +3,7 @@ package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -40,7 +41,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -71,7 +72,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -103,7 +104,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
if err != nil {
response.BadRequest(c, "Failed to refresh token: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -122,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
// Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
response.ErrorFrom(c, err)
return
}
@@ -141,7 +142,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
// Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -159,11 +160,11 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
Credentials: newCredentials,
})
if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, updatedAccount)
response.Success(c, dto.AccountFromService(updatedAccount))
}
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
@@ -192,7 +193,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -220,9 +221,9 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
GroupIDs: req.GroupIDs,
})
if err != nil {
response.InternalError(c, "Failed to create account: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}

View File

@@ -4,6 +4,7 @@ import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -53,11 +54,15 @@ func (h *ProxyHandler) List(c *gin.Context) {
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
if err != nil {
response.InternalError(c, "Failed to list proxies: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, proxies, total, page, pageSize)
out := make([]dto.Proxy, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// GetAll handles getting all active proxies without pagination
@@ -69,20 +74,28 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
if withCount {
proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, proxies)
out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
}
response.Success(c, out)
return
}
proxies, err := h.adminService.GetAllProxies(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, proxies)
out := make([]dto.Proxy, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i]))
}
response.Success(c, out)
}
// GetByID handles getting a proxy by ID
@@ -96,11 +109,11 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID)
if err != nil {
response.NotFound(c, "Proxy not found")
response.ErrorFrom(c, err)
return
}
response.Success(c, proxy)
response.Success(c, dto.ProxyFromService(proxy))
}
// Create handles creating a new proxy
@@ -121,11 +134,11 @@ func (h *ProxyHandler) Create(c *gin.Context) {
Password: strings.TrimSpace(req.Password),
})
if err != nil {
response.BadRequest(c, "Failed to create proxy: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, proxy)
response.Success(c, dto.ProxyFromService(proxy))
}
// Update handles updating a proxy
@@ -153,11 +166,11 @@ func (h *ProxyHandler) Update(c *gin.Context) {
Status: strings.TrimSpace(req.Status),
})
if err != nil {
response.InternalError(c, "Failed to update proxy: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, proxy)
response.Success(c, dto.ProxyFromService(proxy))
}
// Delete handles deleting a proxy
@@ -171,7 +184,7 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteProxy(c.Request.Context(), proxyID)
if err != nil {
response.InternalError(c, "Failed to delete proxy: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -189,7 +202,7 @@ func (h *ProxyHandler) Test(c *gin.Context) {
result, err := h.adminService.TestProxy(c.Request.Context(), proxyID)
if err != nil {
response.InternalError(c, "Failed to test proxy: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -229,11 +242,15 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get proxy accounts: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, accounts, total, page, pageSize)
out := make([]dto.Account, 0, len(accounts))
for i := range accounts {
out = append(out, *dto.AccountFromService(&accounts[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// BatchCreateProxyItem represents a single proxy in batch create request
@@ -272,7 +289,7 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
// Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
if err != nil {
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -43,11 +44,15 @@ func (h *RedeemHandler) List(c *gin.Context) {
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
if err != nil {
response.InternalError(c, "Failed to list redeem codes: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, codes, total, page, pageSize)
out := make([]dto.RedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// GetByID handles getting a redeem code by ID
@@ -61,11 +66,11 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.NotFound(c, "Redeem code not found")
response.ErrorFrom(c, err)
return
}
response.Success(c, code)
response.Success(c, dto.RedeemCodeFromService(code))
}
// Generate handles generating new redeem codes
@@ -85,11 +90,15 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
ValidityDays: req.ValidityDays,
})
if err != nil {
response.InternalError(c, "Failed to generate redeem codes: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, codes)
out := make([]dto.RedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
}
response.Success(c, out)
}
// Delete handles deleting a redeem code
@@ -103,7 +112,7 @@ func (h *RedeemHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.InternalError(c, "Failed to delete redeem code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -123,7 +132,7 @@ func (h *RedeemHandler) BatchDelete(c *gin.Context) {
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
if err != nil {
response.InternalError(c, "Failed to batch delete redeem codes: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -144,11 +153,11 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.InternalError(c, "Failed to expire redeem code: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, code)
response.Success(c, dto.RedeemCodeFromService(code))
}
// GetStats handles getting redeem code statistics
@@ -178,7 +187,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
// Get all codes without pagination (use large page size)
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
if err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -1,7 +1,7 @@
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -27,11 +27,32 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
func (h *SettingHandler) GetSettings(c *gin.Context) {
settings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, settings)
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SmtpHost: settings.SmtpHost,
SmtpPort: settings.SmtpPort,
SmtpUsername: settings.SmtpUsername,
SmtpPassword: settings.SmtpPassword,
SmtpFrom: settings.SmtpFrom,
SmtpFromName: settings.SmtpFromName,
SmtpUseTLS: settings.SmtpUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
})
}
// UpdateSettingsRequest 更新设置请求
@@ -87,7 +108,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SmtpPort = 587
}
settings := &model.SystemSettings{
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
SmtpHost: req.SmtpHost,
@@ -111,18 +132,39 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
response.InternalError(c, "Failed to update settings: "+err.Error())
response.ErrorFrom(c, err)
return
}
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get updated settings: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, updatedSettings)
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SmtpHost: updatedSettings.SmtpHost,
SmtpPort: updatedSettings.SmtpPort,
SmtpUsername: updatedSettings.SmtpUsername,
SmtpPassword: updatedSettings.SmtpPassword,
SmtpFrom: updatedSettings.SmtpFrom,
SmtpFromName: updatedSettings.SmtpFromName,
SmtpUseTLS: updatedSettings.SmtpUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
ApiBaseUrl: updatedSettings.ApiBaseUrl,
ContactInfo: updatedSettings.ContactInfo,
DocUrl: updatedSettings.DocUrl,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
})
}
// TestSmtpRequest 测试SMTP连接请求
@@ -166,7 +208,7 @@ func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
err := h.emailService.TestSmtpConnectionWithConfig(config)
if err != nil {
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -252,7 +294,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
`
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
response.BadRequest(c, "Failed to send test email: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -264,7 +306,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get admin API key status: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -279,7 +321,7 @@ func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to generate admin API key: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -292,7 +334,7 @@ func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
// DELETE /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
response.InternalError(c, "Failed to delete admin API key: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -3,9 +3,10 @@ package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -78,11 +79,15 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
// GetByID handles getting a subscription by ID
@@ -96,11 +101,11 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
if err != nil {
response.NotFound(c, "Subscription not found")
response.ErrorFrom(c, err)
return
}
response.Success(c, subscription)
response.Success(c, dto.UserSubscriptionFromService(subscription))
}
// GetProgress handles getting subscription usage progress
@@ -141,11 +146,11 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
Notes: req.Notes,
})
if err != nil {
response.BadRequest(c, "Failed to assign subscription: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, subscription)
response.Success(c, dto.UserSubscriptionFromService(subscription))
}
// BulkAssign handles bulk assigning subscriptions to multiple users
@@ -168,11 +173,11 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
Notes: req.Notes,
})
if err != nil {
response.InternalError(c, "Failed to bulk assign subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
response.Success(c, dto.BulkAssignResultFromService(result))
}
// Extend handles extending a subscription
@@ -192,11 +197,11 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
if err != nil {
response.InternalError(c, "Failed to extend subscription: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, subscription)
response.Success(c, dto.UserSubscriptionFromService(subscription))
}
// Revoke handles revoking a subscription
@@ -210,7 +215,7 @@ func (h *SubscriptionHandler) Revoke(c *gin.Context) {
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
if err != nil {
response.InternalError(c, "Failed to revoke subscription: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -230,11 +235,15 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to list group subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
// ListByUser handles listing subscriptions for a specific user
@@ -248,19 +257,22 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
if err != nil {
response.InternalError(c, "Failed to list user subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, subscriptions)
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
}
// Helper function to get admin ID from context
func getAdminIDFromContext(c *gin.Context) int64 {
if user, exists := c.Get("user"); exists {
if u, ok := user.(*model.User); ok && u != nil {
return u.ID
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
return 0
}
return 0
return subject.UserID
}

View File

@@ -4,6 +4,7 @@ import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -39,7 +40,7 @@ func (h *UsageHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
// Parse filters
var userID, apiKeyID int64
var userID, apiKeyID, accountID, groupID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
@@ -58,6 +59,47 @@ func (h *UsageHandler) List(c *gin.Context) {
apiKeyID = id
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
id, err := strconv.ParseInt(accountIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account_id")
return
}
accountID = id
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
id, err := strconv.ParseInt(groupIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group_id")
return
}
groupID = id
}
model := c.Query("model")
var stream *bool
if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
stream = &val
}
var billingType *int8
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
if err != nil {
response.BadRequest(c, "Invalid billing_type")
return
}
bt := int8(val)
billingType = &bt
}
// Parse date range
var startTime, endTime *time.Time
if startDateStr := c.Query("start_date"); startDateStr != "" {
@@ -82,19 +124,28 @@ func (h *UsageHandler) List(c *gin.Context) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := usagestats.UsageLogFilters{
UserID: userID,
ApiKeyID: apiKeyID,
StartTime: startTime,
EndTime: endTime,
UserID: userID,
ApiKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
Model: model,
Stream: stream,
BillingType: billingType,
StartTime: startTime,
EndTime: endTime,
}
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, records, result.Total, page, pageSize)
out := make([]dto.UsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromService(&records[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}
// Stats handles getting usage statistics with filters
@@ -158,7 +209,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if apiKeyID > 0 {
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, stats)
@@ -168,7 +219,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if userID > 0 {
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, stats)
@@ -178,7 +229,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
// Get global stats
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -197,7 +248,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
// Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
if err != nil {
response.InternalError(c, "Failed to search users: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -236,7 +287,7 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
if err != nil {
response.InternalError(c, "Failed to search API keys: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -3,6 +3,7 @@ package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -64,11 +65,15 @@ func (h *UserHandler) List(c *gin.Context) {
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
if err != nil {
response.InternalError(c, "Failed to list users: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, users, total, page, pageSize)
out := make([]dto.User, 0, len(users))
for i := range users {
out = append(out, *dto.UserFromService(&users[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// GetByID handles getting a user by ID
@@ -82,11 +87,11 @@ func (h *UserHandler) GetByID(c *gin.Context) {
user, err := h.adminService.GetUser(c.Request.Context(), userID)
if err != nil {
response.NotFound(c, "User not found")
response.ErrorFrom(c, err)
return
}
response.Success(c, user)
response.Success(c, dto.UserFromService(user))
}
// Create handles creating a new user
@@ -109,11 +114,11 @@ func (h *UserHandler) Create(c *gin.Context) {
AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.BadRequest(c, "Failed to create user: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, user)
response.Success(c, dto.UserFromService(user))
}
// Update handles updating a user
@@ -144,11 +149,11 @@ func (h *UserHandler) Update(c *gin.Context) {
AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.InternalError(c, "Failed to update user: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, user)
response.Success(c, dto.UserFromService(user))
}
// Delete handles deleting a user
@@ -162,7 +167,7 @@ func (h *UserHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteUser(c.Request.Context(), userID)
if err != nil {
response.InternalError(c, "Failed to delete user: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -186,11 +191,11 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
if err != nil {
response.InternalError(c, "Failed to update balance: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, user)
response.Success(c, dto.UserFromService(user))
}
// GetUserAPIKeys handles getting user's API keys
@@ -206,11 +211,15 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get user API keys: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, keys, total, page, pageSize)
out := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// GetUserUsage handles getting user's usage statistics
@@ -226,7 +235,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
if err != nil {
response.InternalError(c, "Failed to get user usage: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -3,9 +3,10 @@ package handler
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -40,42 +41,34 @@ type UpdateAPIKeyRequest struct {
// List handles listing user's API keys with pagination
// GET /api/v1/api-keys
func (h *APIKeyHandler) List(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
if err != nil {
response.InternalError(c, "Failed to list API keys: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, keys, result.Total, page, pageSize)
out := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}
// GetByID handles getting a single API key
// GET /api/v1/api-keys/:id
func (h *APIKeyHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -87,31 +80,25 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
if err != nil {
response.NotFound(c, "API key not found")
response.ErrorFrom(c, err)
return
}
// 验证所有权
if key.UserID != user.ID {
if key.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this key")
return
}
response.Success(c, key)
response.Success(c, dto.ApiKeyFromService(key))
}
// Create handles creating a new API key
// POST /api/v1/api-keys
func (h *APIKeyHandler) Create(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -126,27 +113,21 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
GroupID: req.GroupID,
CustomKey: req.CustomKey,
}
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.InternalError(c, "Failed to create API key: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, key)
response.Success(c, dto.ApiKeyFromService(key))
}
// Update handles updating an API key
// PUT /api/v1/api-keys/:id
func (h *APIKeyHandler) Update(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -171,27 +152,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
svcReq.Status = &req.Status
}
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq)
if err != nil {
response.InternalError(c, "Failed to update API key: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, key)
response.Success(c, dto.ApiKeyFromService(key))
}
// Delete handles deleting an API key
// DELETE /api/v1/api-keys/:id
func (h *APIKeyHandler) Delete(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -201,9 +176,9 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
return
}
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
err = h.apiKeyService.Delete(c.Request.Context(), keyID, subject.UserID)
if err != nil {
response.InternalError(c, "Failed to delete API key: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -213,23 +188,21 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
// GetAvailableGroups 获取用户可以绑定的分组列表
// GET /api/v1/groups/available
func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to get available groups: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, groups)
out := make([]dto.Group, 0, len(groups))
for i := range groups {
out = append(out, *dto.GroupFromService(&groups[i]))
}
response.Success(c, out)
}

View File

@@ -1,8 +1,9 @@
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -11,12 +12,14 @@ import (
// AuthHandler handles authentication-related requests
type AuthHandler struct {
authService *service.AuthService
userService *service.UserService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
func NewAuthHandler(authService *service.AuthService, userService *service.UserService) *AuthHandler {
return &AuthHandler{
authService: authService,
userService: userService,
}
}
@@ -49,9 +52,9 @@ type LoginRequest struct {
// AuthResponse 认证响应格式(匹配前端期望)
type AuthResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
User *model.User `json:"user"`
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
User *dto.User `json:"user"`
}
// Register handles user registration
@@ -66,21 +69,21 @@ func (h *AuthHandler) Register(c *gin.Context) {
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if req.VerifyCode == "" {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
}
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
if err != nil {
response.BadRequest(c, "Registration failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: user,
User: dto.UserFromService(user),
})
}
@@ -95,13 +98,13 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
if err != nil {
response.BadRequest(c, "Failed to send verification code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -122,37 +125,37 @@ func (h *AuthHandler) Login(c *gin.Context) {
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
if err != nil {
response.Unauthorized(c, "Login failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: user,
User: dto.UserFromService(user),
})
}
// GetCurrentUser handles getting current authenticated user
// GET /api/v1/auth/me
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
user, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, user)
response.Success(c, dto.UserFromService(user))
}

View File

@@ -0,0 +1,310 @@
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
func UserFromServiceShallow(u *service.User) *User {
if u == nil {
return nil
}
return &User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: u.AllowedGroups,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func UserFromService(u *service.User) *User {
if u == nil {
return nil
}
out := UserFromServiceShallow(u)
if len(u.ApiKeys) > 0 {
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
for i := range u.ApiKeys {
k := u.ApiKeys[i]
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
}
}
if len(u.Subscriptions) > 0 {
out.Subscriptions = make([]UserSubscription, 0, len(u.Subscriptions))
for i := range u.Subscriptions {
s := u.Subscriptions[i]
out.Subscriptions = append(out.Subscriptions, *UserSubscriptionFromService(&s))
}
}
return out
}
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
if k == nil {
return nil
}
return &ApiKey{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User),
Group: GroupFromServiceShallow(k.Group),
}
}
func GroupFromServiceShallow(g *service.Group) *Group {
if g == nil {
return nil
}
return &Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
}
}
func GroupFromService(g *service.Group) *Group {
if g == nil {
return nil
}
out := GroupFromServiceShallow(g)
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
for i := range g.AccountGroups {
ag := g.AccountGroups[i]
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
}
}
return out
}
func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
}
return &Account{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
}
}
func AccountFromService(a *service.Account) *Account {
if a == nil {
return nil
}
out := AccountFromServiceShallow(a)
out.Proxy = ProxyFromService(a.Proxy)
if len(a.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(a.AccountGroups))
for i := range a.AccountGroups {
ag := a.AccountGroups[i]
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
}
}
if len(a.Groups) > 0 {
out.Groups = make([]*Group, 0, len(a.Groups))
for _, g := range a.Groups {
out.Groups = append(out.Groups, GroupFromServiceShallow(g))
}
}
return out
}
func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup {
if ag == nil {
return nil
}
return &AccountGroup{
AccountID: ag.AccountID,
GroupID: ag.GroupID,
Priority: ag.Priority,
CreatedAt: ag.CreatedAt,
Account: AccountFromServiceShallow(ag.Account),
Group: GroupFromServiceShallow(ag.Group),
}
}
func ProxyFromService(p *service.Proxy) *Proxy {
if p == nil {
return nil
}
return &Proxy{
ID: p.ID,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWithAccountCount {
if p == nil {
return nil
}
return &ProxyWithAccountCount{
Proxy: *ProxyFromService(&p.Proxy),
AccountCount: p.AccountCount,
}
}
func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
if rc == nil {
return nil
}
return &RedeemCode{
ID: rc.ID,
Code: rc.Code,
Type: rc.Type,
Value: rc.Value,
Status: rc.Status,
UsedBy: rc.UsedBy,
UsedAt: rc.UsedAt,
Notes: rc.Notes,
CreatedAt: rc.CreatedAt,
GroupID: rc.GroupID,
ValidityDays: rc.ValidityDays,
User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group),
}
}
func UsageLogFromService(l *service.UsageLog) *UsageLog {
if l == nil {
return nil
}
return &UsageLog{
ID: l.ID,
UserID: l.UserID,
ApiKeyID: l.ApiKeyID,
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens,
OutputTokens: l.OutputTokens,
CacheCreationTokens: l.CacheCreationTokens,
CacheReadTokens: l.CacheReadTokens,
CacheCreation5mTokens: l.CacheCreation5mTokens,
CacheCreation1hTokens: l.CacheCreation1hTokens,
InputCost: l.InputCost,
OutputCost: l.OutputCost,
CacheCreationCost: l.CacheCreationCost,
CacheReadCost: l.CacheReadCost,
TotalCost: l.TotalCost,
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
BillingType: l.BillingType,
Stream: l.Stream,
DurationMs: l.DurationMs,
FirstTokenMs: l.FirstTokenMs,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
ApiKey: ApiKeyFromService(l.ApiKey),
Account: AccountFromService(l.Account),
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),
}
}
func SettingFromService(s *service.Setting) *Setting {
if s == nil {
return nil
}
return &Setting{
ID: s.ID,
Key: s.Key,
Value: s.Value,
UpdatedAt: s.UpdatedAt,
}
}
func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscription {
if sub == nil {
return nil
}
return &UserSubscription{
ID: sub.ID,
UserID: sub.UserID,
GroupID: sub.GroupID,
StartsAt: sub.StartsAt,
ExpiresAt: sub.ExpiresAt,
Status: sub.Status,
DailyWindowStart: sub.DailyWindowStart,
WeeklyWindowStart: sub.WeeklyWindowStart,
MonthlyWindowStart: sub.MonthlyWindowStart,
DailyUsageUSD: sub.DailyUsageUSD,
WeeklyUsageUSD: sub.WeeklyUsageUSD,
MonthlyUsageUSD: sub.MonthlyUsageUSD,
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
CreatedAt: sub.CreatedAt,
UpdatedAt: sub.UpdatedAt,
User: UserFromServiceShallow(sub.User),
Group: GroupFromServiceShallow(sub.Group),
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}
func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult {
if r == nil {
return nil
}
subs := make([]UserSubscription, 0, len(r.Subscriptions))
for i := range r.Subscriptions {
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
}
return &BulkAssignResult{
SuccessCount: r.SuccessCount,
FailedCount: r.FailedCount,
Subscriptions: subs,
Errors: r.Errors,
}
}

View File

@@ -0,0 +1,43 @@
package dto
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"`
}

View File

@@ -0,0 +1,219 @@
package dto
import "time"
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
Status string `json:"status"`
AllowedGroups []int64 `json:"allowed_groups"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ApiKeys []ApiKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
type ApiKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Key string `json:"key"`
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
type Group struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
Status string `json:"status"`
SubscriptionType string `json:"subscription_type"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
}
type Account struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Schedulable bool `json:"schedulable"`
RateLimitedAt *time.Time `json:"rate_limited_at"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `json:"session_window_status"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
GroupIDs []int64 `json:"group_ids,omitempty"`
Groups []*Group `json:"groups,omitempty"`
}
type AccountGroup struct {
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Priority int `json:"priority"`
CreatedAt time.Time `json:"created_at"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
}
type Proxy struct {
ID int64 `json:"id"`
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"-"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}
type RedeemCode struct {
ID int64 `json:"id"`
Code string `json:"code"`
Type string `json:"type"`
Value float64 `json:"value"`
Status string `json:"status"`
UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
type UsageLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationTokens int `json:"cache_creation_tokens"`
CacheReadTokens int `json:"cache_read_tokens"`
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
ApiKey *ApiKey `json:"api_key,omitempty"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`
}
type Setting struct {
ID int64 `json:"id"`
Key string `json:"key"`
Value string `json:"value"`
UpdatedAt time.Time `json:"updated_at"`
}
type UserSubscription struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
GroupID int64 `json:"group_id"`
StartsAt time.Time `json:"starts_at"`
ExpiresAt time.Time `json:"expires_at"`
Status string `json:"status"`
DailyWindowStart *time.Time `json:"daily_window_start"`
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
DailyUsageUSD float64 `json:"daily_usage_usd"`
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
AssignedBy *int64 `json:"assigned_by"`
AssignedAt time.Time `json:"assigned_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
}
type BulkAssignResult struct {
SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"`
Subscriptions []UserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
}

View File

@@ -3,6 +3,7 @@ package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
@@ -10,10 +11,9 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -22,15 +22,23 @@ import (
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
geminiCompatService *service.GeminiMessagesCompatService
userService *service.UserService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
}
// NewGatewayHandler creates a new GatewayHandler
func NewGatewayHandler(gatewayService *service.GatewayService, userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService) *GatewayHandler {
func NewGatewayHandler(
gatewayService *service.GatewayService,
geminiCompatService *service.GeminiMessagesCompatService,
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
) *GatewayHandler {
return &GatewayHandler{
gatewayService: gatewayService,
geminiCompatService: geminiCompatService,
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
@@ -41,13 +49,13 @@ func NewGatewayHandler(gatewayService *service.GatewayService, userService *serv
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware.GetApiKeyFromContext(c)
apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
@@ -79,11 +87,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
streamStarted := false
// 获取订阅信息可能为nil- 提前获取用于后续检查
subscription, _ := middleware.GetSubscriptionFromContext(c)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
@@ -92,10 +100,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 确保在函数退出时减少wait计数
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
@@ -106,7 +114,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return
@@ -115,63 +123,170 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
platform := ""
if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
if platform == service.PlatformGemini {
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// 转发请求
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Forward request failed: %v", err)
return
}
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
// 转发请求
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if err != nil {
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Forward request failed: %v", err)
return
}
// 异步记录使用量subscription已在函数开头获取
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: user,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
for {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
}()
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// 转发请求
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Forward request failed: %v", err)
return
}
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
}
// Models handles listing available models
// GET /v1/models
// Returns different model lists based on the API key's group platform
func (h *GatewayHandler) Models(c *gin.Context) {
apiKey, _ := middleware.GetApiKeyFromContext(c)
apiKey, _ := middleware2.GetApiKeyFromContext(c)
// Return OpenAI models for OpenAI platform groups
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
@@ -192,13 +307,13 @@ func (h *GatewayHandler) Models(c *gin.Context) {
// Usage handles getting account balance for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
@@ -206,7 +321,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
// 订阅模式:返回订阅限额信息
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware.GetSubscriptionFromContext(c)
subscription, ok := middleware2.GetSubscriptionFromContext(c)
if !ok {
h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription")
return
@@ -223,7 +338,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
}
// 余额模式:返回钱包余额
latestUser, err := h.userService.GetByID(c.Request.Context(), user.ID)
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return
@@ -241,7 +356,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
// 逻辑:
// 1. 如果日/周/月任一限额达到100%返回0
// 2. 否则返回所有已配置周期中剩余额度的最小值
func (h *GatewayHandler) calculateSubscriptionRemaining(group *model.Group, sub *model.UserSubscription) float64 {
func (h *GatewayHandler) calculateSubscriptionRemaining(group *service.Group, sub *service.UserSubscription) float64 {
var remainingValues []float64
// 检查日限额
@@ -292,6 +407,28 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
@@ -328,13 +465,13 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
// 特点:校验订阅/余额,但不计算并发、不记录使用量
func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware.GetApiKeyFromContext(c)
apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
_, ok = middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
@@ -362,11 +499,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
// 获取订阅信息可能为nil
subscription, _ := middleware.GetSubscriptionFromContext(c)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
return
}

View File

@@ -6,7 +6,6 @@ import (
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -69,11 +68,11 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return nil, err
}
@@ -83,17 +82,17 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
return h.waitForSlotWithPing(c, "user", userID, maxConcurrency, isStream, streamStarted)
}
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, err
}
@@ -103,7 +102,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
return h.waitForSlotWithPing(c, "account", accountID, maxConcurrency, isStream, streamStarted)
}
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.

View File

@@ -0,0 +1,320 @@
package handler
import (
"context"
"errors"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models")
if err != nil {
googleError(c, http.StatusBadGateway, err.Error())
return
}
if shouldFallbackGeminiModels(res) {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
writeUpstreamResponse(c, res)
}
// GeminiV1BetaGetModel proxies:
// GET /v1beta/models/{model}
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
modelName := strings.TrimSpace(c.Param("model"))
if modelName == "" {
googleError(c, http.StatusBadRequest, "Missing model in URL")
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models/"+modelName)
if err != nil {
googleError(c, http.StatusBadGateway, err.Error())
return
}
if shouldFallbackGeminiModels(res) {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
writeUpstreamResponse(c, res)
}
// GeminiV1BetaModels proxies Gemini native REST endpoints like:
// POST /v1beta/models/{model}:generateContent
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
authSubject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok {
googleError(c, http.StatusInternalServerError, "User context not found")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))
if err != nil {
googleError(c, http.StatusNotFound, err.Error())
return
}
stream := action == "streamGenerateContent"
body, err := io.ReadAll(c.Request.Body)
if err != nil {
googleError(c, http.StatusBadRequest, "Failed to read request body")
return
}
if len(body) == 0 {
googleError(c, http.StatusBadRequest, "Request body is empty")
return
}
// Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
// 0) wait queue check
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
} else if !canWait {
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
// 1) user concurrency slot
streamStarted := false
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error())
return
}
// 3) select account (sticky session based on request body)
sessionHash := h.gatewayService.GenerateSessionHash(body)
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
handleGeminiFailoverExhausted(c, lastFailoverStatus)
return
}
// 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
// 5) forward (writes response to client)
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
handleGeminiFailoverExhausted(c, lastFailoverStatus)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err)
return
}
// 6) record usage async
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
}
func parseGeminiModelAction(rest string) (model string, action string, err error) {
rest = strings.TrimSpace(rest)
if rest == "" {
return "", "", &pathParseError{"missing path"}
}
// Standard: {model}:{action}
if i := strings.Index(rest, ":"); i > 0 && i < len(rest)-1 {
return rest[:i], rest[i+1:], nil
}
// Fallback: {model}/{action}
if i := strings.Index(rest, "/"); i > 0 && i < len(rest)-1 {
return rest[:i], rest[i+1:], nil
}
return "", "", &pathParseError{"invalid model action path"}
}
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
status, message := mapGeminiUpstreamError(statusCode)
googleError(c, status, message)
}
func mapGeminiUpstreamError(statusCode int) (int, string) {
switch statusCode {
case 401:
return http.StatusBadGateway, "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "Upstream request failed"
}
}
type pathParseError struct{ msg string }
func (e *pathParseError) Error() string { return e.msg }
func googleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
}
func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
if res == nil {
googleError(c, http.StatusBadGateway, "Empty upstream response")
return
}
for k, vv := range res.Headers {
// Avoid overriding content-length and hop-by-hop headers.
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
continue
}
for _, v := range vv {
c.Writer.Header().Add(k, v)
}
}
contentType := res.Headers.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(res.StatusCode, contentType, res.Body)
}
func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
if res == nil {
return true
}
if res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden {
return false
}
if strings.Contains(strings.ToLower(res.Headers.Get("Www-Authenticate")), "insufficient_scope") {
return true
}
if strings.Contains(strings.ToLower(string(res.Body)), "insufficient authentication scopes") {
return true
}
if strings.Contains(strings.ToLower(string(res.Body)), "access_token_scope_insufficient") {
return true
}
return false
}

View File

@@ -12,6 +12,7 @@ type AdminHandlers struct {
Account *admin.AccountHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Setting *admin.SettingHandler

View File

@@ -3,14 +3,15 @@ package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -40,13 +41,13 @@ func NewOpenAIGatewayHandler(
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware.GetApiKeyFromContext(c)
apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
@@ -91,11 +92,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
streamStarted := false
// Get subscription info (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
@@ -104,10 +105,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
// Ensure wait count is decremented when function exits
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted)
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
@@ -118,7 +119,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return
@@ -127,49 +128,74 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Generate session hash (from header for OpenAI)
sessionHash := h.gatewayService.GenerateSessionHash(c)
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if err != nil {
// Error response already handled in Forward, just log
log.Printf("Forward request failed: %v", err)
return
}
// Async record usage
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
ApiKey: apiKey,
User: user,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
for {
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
}()
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// Error response already handled in Forward, just log
log.Printf("Forward request failed: %v", err)
return
}
// Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
@@ -178,6 +204,28 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {

View File

@@ -1,8 +1,9 @@
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -37,15 +38,9 @@ type RedeemResponse struct {
// Redeem handles redeeming a code
// POST /api/v1/redeem
func (h *RedeemHandler) Redeem(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -55,38 +50,36 @@ func (h *RedeemHandler) Redeem(c *gin.Context) {
return
}
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
result, err := h.redeemService.Redeem(c.Request.Context(), subject.UserID, req.Code)
if err != nil {
response.BadRequest(c, "Failed to redeem code: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
response.Success(c, dto.RedeemCodeFromService(result))
}
// GetHistory returns the user's redemption history
// GET /api/v1/redeem/history
func (h *RedeemHandler) GetHistory(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
// Default limit is 25
limit := 25
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), subject.UserID, limit)
if err != nil {
response.InternalError(c, "Failed to get history: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, codes)
out := make([]dto.RedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
}
response.Success(c, out)
}

View File

@@ -1,6 +1,7 @@
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -26,10 +27,21 @@ func NewSettingHandler(settingService *service.SettingService, version string) *
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
settings, err := h.settingService.GetPublicSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error())
response.ErrorFrom(c, err)
return
}
settings.Version = h.version
response.Success(c, settings)
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
Version: h.version,
})
}

View File

@@ -1,8 +1,9 @@
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -25,7 +26,7 @@ type SubscriptionSummaryItem struct {
// SubscriptionProgressInfo represents subscription with progress info
type SubscriptionProgressInfo struct {
Subscription *model.UserSubscription `json:"subscription"`
Subscription *dto.UserSubscription `json:"subscription"`
Progress *service.SubscriptionProgress `json:"progress"`
}
@@ -44,70 +45,60 @@ func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *S
// List handles listing current user's subscriptions
// GET /api/v1/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, subscriptions)
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
}
// GetActive handles getting current user's active subscriptions
// GET /api/v1/subscriptions/active
func (h *SubscriptionHandler) GetActive(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to get active subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, subscriptions)
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
}
// GetProgress handles getting subscription progress for current user
// GET /api/v1/subscriptions/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions with progress
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -120,7 +111,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
continue
}
result = append(result, SubscriptionProgressInfo{
Subscription: sub,
Subscription: dto.UserSubscriptionFromService(sub),
Progress: progress,
})
}
@@ -131,22 +122,16 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
// GetSummary handles getting a summary of current user's subscription status
// GET /api/v1/subscriptions/summary
func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -4,10 +4,12 @@ import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -30,15 +32,9 @@ func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.
// List handles listing usage records with pagination
// GET /api/v1/usage
func (h *UsageHandler) List(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -55,10 +51,10 @@ func (h *UsageHandler) List(c *gin.Context) {
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "API key not found")
response.ErrorFrom(c, err)
return
}
if apiKey.UserID != user.ID {
if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's usage records")
return
}
@@ -66,36 +62,82 @@ func (h *UsageHandler) List(c *gin.Context) {
apiKeyID = id
}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog
var result *pagination.PaginationResult
var err error
// Parse additional filters
model := c.Query("model")
if apiKeyID > 0 {
records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params)
} else {
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
var stream *bool
if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
stream = &val
}
var billingType *int8
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
if err != nil {
response.BadRequest(c, "Invalid billing_type")
return
}
bt := int8(val)
billingType = &bt
}
// Parse date range
var startTime, endTime *time.Time
if startDateStr := c.Query("start_date"); startDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
startTime = &t
}
if endDateStr := c.Query("end_date"); endDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// Set end time to end of day
t = t.Add(24*time.Hour - time.Nanosecond)
endTime = &t
}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := usagestats.UsageLogFilters{
UserID: subject.UserID, // Always filter by current user for security
ApiKeyID: apiKeyID,
Model: model,
Stream: stream,
BillingType: billingType,
StartTime: startTime,
EndTime: endTime,
}
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, records, result.Total, page, pageSize)
out := make([]dto.UsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromService(&records[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}
// GetByID handles getting a single usage record
// GET /api/v1/usage/:id
func (h *UsageHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -107,31 +149,25 @@ func (h *UsageHandler) GetByID(c *gin.Context) {
record, err := h.usageService.GetByID(c.Request.Context(), usageID)
if err != nil {
response.NotFound(c, "Usage record not found")
response.ErrorFrom(c, err)
return
}
// 验证所有权
if record.UserID != user.ID {
if record.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this record")
return
}
response.Success(c, record)
response.Success(c, dto.UsageLogFromService(record))
}
// Stats handles getting usage statistics
// GET /api/v1/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -149,7 +185,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
response.NotFound(c, "API key not found")
return
}
if apiKey.UserID != user.ID {
if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's statistics")
return
}
@@ -201,10 +237,10 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if apiKeyID > 0 {
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
} else {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
}
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -245,21 +281,15 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
// DashboardStats handles getting user dashboard statistics
// GET /api/v1/usage/dashboard/stats
func (h *UsageHandler) DashboardStats(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to get dashboard statistics")
response.ErrorFrom(c, err)
return
}
@@ -269,24 +299,18 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
// DashboardTrend handles getting user usage trend data
// GET /api/v1/usage/dashboard/trend
func (h *UsageHandler) DashboardTrend(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
startTime, endTime := parseUserTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), subject.UserID, startTime, endTime, granularity)
if err != nil {
response.InternalError(c, "Failed to get usage trend")
response.ErrorFrom(c, err)
return
}
@@ -301,23 +325,17 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
// DashboardModels handles getting user model usage statistics
// GET /api/v1/usage/dashboard/models
func (h *UsageHandler) DashboardModels(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
startTime, endTime := parseUserTimeRange(c)
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), subject.UserID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get model statistics")
response.ErrorFrom(c, err)
return
}
@@ -336,15 +354,9 @@ type BatchApiKeysUsageRequest struct {
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
// POST /api/v1/usage/dashboard/api-keys-usage
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -359,24 +371,16 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
return
}
// Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil {
response.InternalError(c, "Failed to verify API key ownership")
// Limit the number of API key IDs to prevent SQL parameter overflow
if len(req.ApiKeyIDs) > 100 {
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
return
}
userApiKeyIDs := make(map[int64]bool)
for _, key := range userApiKeys {
userApiKeyIDs[key.ID] = true
}
// Filter to only include user's own API keys
validApiKeyIDs := make([]int64, 0)
for _, id := range req.ApiKeyIDs {
if userApiKeyIDs[id] {
validApiKeyIDs = append(validApiKeyIDs, id)
}
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
if len(validApiKeyIDs) == 0 {
@@ -386,7 +390,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
if err != nil {
response.InternalError(c, "Failed to get API key usage stats")
response.ErrorFrom(c, err)
return
}

View File

@@ -1,8 +1,9 @@
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -35,42 +36,30 @@ type UpdateProfileRequest struct {
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to get user profile: "+err.Error())
response.ErrorFrom(c, err)
return
}
// 清空notes字段普通用户不应看到备注
userData.Notes = ""
response.Success(c, userData)
response.Success(c, dto.UserFromService(userData))
}
// ChangePassword handles changing user password
// POST /api/v1/users/me/password
func (h *UserHandler) ChangePassword(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -84,9 +73,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
CurrentPassword: req.OldPassword,
NewPassword: req.NewPassword,
}
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
err := h.userService.ChangePassword(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.BadRequest(c, "Failed to change password: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -96,15 +85,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
// UpdateProfile handles updating user profile
// PUT /api/v1/users/me
func (h *UserHandler) UpdateProfile(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -118,14 +101,14 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
Username: req.Username,
Wechat: req.Wechat,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.BadRequest(c, "Failed to update profile: "+err.Error())
response.ErrorFrom(c, err)
return
}
// 清空notes字段普通用户不应看到备注
updatedUser.Notes = ""
response.Success(c, updatedUser)
response.Success(c, dto.UserFromService(updatedUser))
}

View File

@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
accountHandler *admin.AccountHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler,
settingHandler *admin.SettingHandler,
@@ -29,6 +30,7 @@ func ProvideAdminHandlers(
Account: accountHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
Proxy: proxyHandler,
Redeem: redeemHandler,
Setting: settingHandler,
@@ -95,6 +97,7 @@ var ProviderSet = wire.NewSet(
admin.NewAccountHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewSettingHandler,

View File

@@ -2,8 +2,8 @@ package infrastructure
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/repository"
"gorm.io/driver/postgres"
"gorm.io/gorm"
@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil {
if err := repository.AutoMigrate(db); err != nil {
return nil, err
}

View File

@@ -0,0 +1,158 @@
package errors
import (
"errors"
"fmt"
"net/http"
)
const (
UnknownCode = http.StatusInternalServerError
UnknownReason = ""
UnknownMessage = "internal error"
)
type Status struct {
Code int32 `json:"code"`
Reason string `json:"reason,omitempty"`
Message string `json:"message"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// ApplicationError is the standard error type used to control HTTP responses.
//
// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500).
type ApplicationError struct {
Status
cause error
}
// Error is kept for backwards compatibility within this package.
type Error = ApplicationError
func (e *ApplicationError) Error() string {
if e == nil {
return "<nil>"
}
if e.cause == nil {
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata)
}
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause)
}
// Unwrap provides compatibility for Go 1.13 error chains.
func (e *ApplicationError) Unwrap() error { return e.cause }
// Is matches each error in the chain with the target value.
func (e *ApplicationError) Is(err error) bool {
if se := new(ApplicationError); errors.As(err, &se) {
return se.Code == e.Code && se.Reason == e.Reason
}
return false
}
// WithCause attaches the underlying cause of the error.
func (e *ApplicationError) WithCause(cause error) *ApplicationError {
err := Clone(e)
err.cause = cause
return err
}
// WithMetadata deep-copies the given metadata map.
func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError {
err := Clone(e)
if md == nil {
err.Metadata = nil
return err
}
err.Metadata = make(map[string]string, len(md))
for k, v := range md {
err.Metadata[k] = v
}
return err
}
// New returns an error object for the code, message.
func New(code int, reason, message string) *ApplicationError {
return &ApplicationError{
Status: Status{
Code: int32(code),
Message: message,
Reason: reason,
},
}
}
// Newf New(code fmt.Sprintf(format, a...))
func Newf(code int, reason, format string, a ...any) *ApplicationError {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Errorf returns an error object for the code, message and error info.
func Errorf(code int, reason, format string, a ...any) error {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Code returns the http code for an error.
// It supports wrapped errors.
func Code(err error) int {
if err == nil {
return http.StatusOK
}
return int(FromError(err).Code)
}
// Reason returns the reason for a particular error.
// It supports wrapped errors.
func Reason(err error) string {
if err == nil {
return UnknownReason
}
return FromError(err).Reason
}
// Message returns the message for a particular error.
// It supports wrapped errors.
func Message(err error) string {
if err == nil {
return ""
}
return FromError(err).Message
}
// Clone deep clone error to a new error.
func Clone(err *ApplicationError) *ApplicationError {
if err == nil {
return nil
}
var metadata map[string]string
if err.Metadata != nil {
metadata = make(map[string]string, len(err.Metadata))
for k, v := range err.Metadata {
metadata[k] = v
}
}
return &ApplicationError{
cause: err.cause,
Status: Status{
Code: err.Code,
Reason: err.Reason,
Message: err.Message,
Metadata: metadata,
},
}
}
// FromError tries to convert an error to *ApplicationError.
// It supports wrapped errors.
func FromError(err error) *ApplicationError {
if err == nil {
return nil
}
if se := new(ApplicationError); errors.As(err, &se) {
return se
}
// Fall back to a generic internal error.
return New(UnknownCode, UnknownReason, UnknownMessage).WithCause(err)
}

View File

@@ -0,0 +1,168 @@
//go:build unit
package errors
import (
stderrors "errors"
"fmt"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestApplicationError_Basics(t *testing.T) {
tests := []struct {
name string
err *ApplicationError
want Status
wantIs bool
target error
wrapped error
}{
{
name: "new",
err: New(400, "BAD_REQUEST", "invalid input"),
want: Status{
Code: 400,
Reason: "BAD_REQUEST",
Message: "invalid input",
},
},
{
name: "is_matches_code_and_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "UNAUTHORIZED", "ignored message"),
wantIs: true,
},
{
name: "is_does_not_match_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "DIFFERENT", "ignored message"),
wantIs: false,
},
{
name: "from_error_unwraps_wrapped_application_error",
err: New(404, "NOT_FOUND", "missing"),
wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")),
want: Status{
Code: 404,
Reason: "NOT_FOUND",
Message: "missing",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.err != nil {
require.Equal(t, tt.want, tt.err.Status)
}
if tt.target != nil {
require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target))
}
if tt.wrapped != nil {
got := FromError(tt.wrapped)
require.Equal(t, tt.want, got.Status)
}
})
}
}
func TestApplicationError_WithMetadataDeepCopy(t *testing.T) {
tests := []struct {
name string
md map[string]string
}{
{name: "non_nil", md: map[string]string{"a": "1"}},
{name: "nil", md: nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md)
if tt.md == nil {
require.Nil(t, appErr.Metadata)
return
}
tt.md["a"] = "changed"
require.Equal(t, "1", appErr.Metadata["a"])
})
}
}
func TestFromError_Generic(t *testing.T) {
tests := []struct {
name string
err error
wantCode int32
wantReason string
wantMsg string
}{
{
name: "plain_error",
err: stderrors.New("boom"),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: UnknownMessage,
},
{
name: "wrapped_plain_error",
err: fmt.Errorf("wrap: %w", io.EOF),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: UnknownMessage,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FromError(tt.err)
require.Equal(t, tt.wantCode, got.Code)
require.Equal(t, tt.wantReason, got.Reason)
require.Equal(t, tt.wantMsg, got.Message)
require.Equal(t, tt.err, got.Unwrap())
})
}
}
func TestToHTTP(t *testing.T) {
tests := []struct {
name string
err error
wantStatusCode int
wantBody Status
}{
{
name: "nil_error",
err: nil,
wantStatusCode: http.StatusOK,
wantBody: Status{Code: int32(http.StatusOK)},
},
{
name: "application_error",
err: Forbidden("FORBIDDEN", "no access"),
wantStatusCode: http.StatusForbidden,
wantBody: Status{
Code: int32(http.StatusForbidden),
Reason: "FORBIDDEN",
Message: "no access",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
code, body := ToHTTP(tt.err)
require.Equal(t, tt.wantStatusCode, code)
require.Equal(t, tt.wantBody, body)
})
}
}

View File

@@ -0,0 +1,21 @@
package errors
import "net/http"
// ToHTTP converts an error into an HTTP status code and a JSON-serializable body.
//
// The returned body matches the project's Status shape:
// { code, reason, message, metadata }.
func ToHTTP(err error) (statusCode int, body Status) {
if err == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
appErr := FromError(err)
if appErr == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
cloned := Clone(appErr)
return int(cloned.Code), cloned.Status
}

View File

@@ -0,0 +1,114 @@
// nolint:mnd
package errors
import "net/http"
// BadRequest new BadRequest error that is mapped to a 400 response.
func BadRequest(reason, message string) *ApplicationError {
return New(http.StatusBadRequest, reason, message)
}
// IsBadRequest determines if err is an error which indicates a BadRequest error.
// It supports wrapped errors.
func IsBadRequest(err error) bool {
return Code(err) == http.StatusBadRequest
}
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
func TooManyRequests(reason, message string) *ApplicationError {
return New(http.StatusTooManyRequests, reason, message)
}
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
// It supports wrapped errors.
func IsTooManyRequests(err error) bool {
return Code(err) == http.StatusTooManyRequests
}
// Unauthorized new Unauthorized error that is mapped to a 401 response.
func Unauthorized(reason, message string) *ApplicationError {
return New(http.StatusUnauthorized, reason, message)
}
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
// It supports wrapped errors.
func IsUnauthorized(err error) bool {
return Code(err) == http.StatusUnauthorized
}
// Forbidden new Forbidden error that is mapped to a 403 response.
func Forbidden(reason, message string) *ApplicationError {
return New(http.StatusForbidden, reason, message)
}
// IsForbidden determines if err is an error which indicates a Forbidden error.
// It supports wrapped errors.
func IsForbidden(err error) bool {
return Code(err) == http.StatusForbidden
}
// NotFound new NotFound error that is mapped to a 404 response.
func NotFound(reason, message string) *ApplicationError {
return New(http.StatusNotFound, reason, message)
}
// IsNotFound determines if err is an error which indicates an NotFound error.
// It supports wrapped errors.
func IsNotFound(err error) bool {
return Code(err) == http.StatusNotFound
}
// Conflict new Conflict error that is mapped to a 409 response.
func Conflict(reason, message string) *ApplicationError {
return New(http.StatusConflict, reason, message)
}
// IsConflict determines if err is an error which indicates a Conflict error.
// It supports wrapped errors.
func IsConflict(err error) bool {
return Code(err) == http.StatusConflict
}
// InternalServer new InternalServer error that is mapped to a 500 response.
func InternalServer(reason, message string) *ApplicationError {
return New(http.StatusInternalServerError, reason, message)
}
// IsInternalServer determines if err is an error which indicates an Internal error.
// It supports wrapped errors.
func IsInternalServer(err error) bool {
return Code(err) == http.StatusInternalServerError
}
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
func ServiceUnavailable(reason, message string) *ApplicationError {
return New(http.StatusServiceUnavailable, reason, message)
}
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
// It supports wrapped errors.
func IsServiceUnavailable(err error) bool {
return Code(err) == http.StatusServiceUnavailable
}
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
func GatewayTimeout(reason, message string) *ApplicationError {
return New(http.StatusGatewayTimeout, reason, message)
}
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
// It supports wrapped errors.
func IsGatewayTimeout(err error) bool {
return Code(err) == http.StatusGatewayTimeout
}
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
func ClientClosed(reason, message string) *ApplicationError {
return New(499, reason, message)
}
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
// It supports wrapped errors.
func IsClientClosed(err error) bool {
return Code(err) == 499
}

View File

@@ -1,415 +0,0 @@
package model
import (
"database/sql/driver"
"encoding/json"
"errors"
"time"
"gorm.io/gorm"
)
// JSONB 用于存储JSONB数据
type JSONB map[string]any
func (j JSONB) Value() (driver.Value, error) {
if j == nil {
return nil, nil
}
return json.Marshal(j)
}
func (j *JSONB) Scan(value any) error {
if value == nil {
*j = nil
return nil
}
bytes, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(bytes, j)
}
type Account struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Platform string `gorm:"size:50;not null" json:"platform"` // anthropic/openai/gemini
Type string `gorm:"size:20;not null" json:"type"` // oauth/apikey
Credentials JSONB `gorm:"type:jsonb;default:'{}'" json:"credentials"` // 凭证(加密存储)
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
ProxyID *int64 `gorm:"index" json:"proxy_id"`
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100越小越高
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
ErrorMessage string `gorm:"type:text" json:"error_message"`
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 调度控制
Schedulable bool `gorm:"default:true;not null" json:"schedulable"`
// 限流状态 (429)
RateLimitedAt *time.Time `gorm:"index" json:"rate_limited_at"`
RateLimitResetAt *time.Time `gorm:"index" json:"rate_limit_reset_at"`
// 过载状态 (529)
OverloadUntil *time.Time `gorm:"index" json:"overload_until"`
// 5小时时间窗口
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `gorm:"size:20" json:"session_window_status"` // allowed/allowed_warning/rejected
// 关联
Proxy *Proxy `gorm:"foreignKey:ProxyID" json:"proxy,omitempty"`
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
Groups []*Group `gorm:"-" json:"groups,omitempty"`
}
func (Account) TableName() string {
return "accounts"
}
// IsActive 检查是否激活
func (a *Account) IsActive() bool {
return a.Status == "active"
}
// IsSchedulable 检查账号是否可调度
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
}
now := time.Now()
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false
}
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
return true
}
// IsRateLimited 检查是否处于限流状态
func (a *Account) IsRateLimited() bool {
if a.RateLimitResetAt == nil {
return false
}
return time.Now().Before(*a.RateLimitResetAt)
}
// IsOverloaded 检查是否处于过载状态
func (a *Account) IsOverloaded() bool {
if a.OverloadUntil == nil {
return false
}
return time.Now().Before(*a.OverloadUntil)
}
// IsOAuth 检查是否为OAuth类型账号包括oauth和setup-token
func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
}
// CanGetUsage 检查账号是否可以获取usage信息只有oauth类型可以setup-token没有profile权限
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}
// GetCredential 获取凭证字段
func (a *Account) GetCredential(key string) string {
if a.Credentials == nil {
return ""
}
if v, ok := a.Credentials[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// GetModelMapping 获取模型映射配置
// 返回格式: map[请求模型名]实际模型名
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["model_mapping"]
if !ok || raw == nil {
return nil
}
// 处理map[string]interface{}类型
if m, ok := raw.(map[string]any); ok {
result := make(map[string]string)
for k, v := range m {
if s, ok := v.(string); ok {
result[k] = s
}
}
if len(result) > 0 {
return result
}
}
return nil
}
// IsModelSupported 检查请求的模型是否被该账号支持
// 如果没有设置模型映射,则支持所有模型
func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return true // 没有映射配置,支持所有模型
}
_, exists := mapping[requestedModel]
return exists
}
// GetMappedModel 获取映射后的实际模型名
// 如果没有映射,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return requestedModel
}
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel
}
return requestedModel
}
// GetBaseURL 获取API基础URL用于apikey类型账号
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
return ""
}
baseURL := a.GetCredential("base_url")
if baseURL == "" {
return "https://api.anthropic.com" // 默认URL
}
return baseURL
}
// GetExtraString 从Extra字段获取字符串值
func (a *Account) GetExtraString(key string) string {
if a.Extra == nil {
return ""
}
if v, ok := a.Extra[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// IsCustomErrorCodesEnabled 检查是否启用自定义错误码功能(仅适用于 apikey 类型)
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// GetCustomErrorCodes 获取自定义错误码列表
func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["custom_error_codes"]
if !ok || raw == nil {
return nil
}
// 处理 []interface{} 类型JSON反序列化后的格式
if arr, ok := raw.([]any); ok {
result := make([]int, 0, len(arr))
for _, v := range arr {
// JSON 数字默认解析为 float64
if f, ok := v.(float64); ok {
result = append(result, int(f))
}
}
return result
}
return nil
}
// ShouldHandleErrorCode 检查指定错误码是否应该被处理(停止调度/标记限流等)
// 如果未启用自定义错误码或列表为空,返回 true使用默认策略
// 如果启用且列表非空,只有在列表中的错误码才返回 true
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
if !a.IsCustomErrorCodesEnabled() {
return true // 未启用,使用默认策略
}
codes := a.GetCustomErrorCodes()
if len(codes) == 0 {
return true // 启用但列表为空fallback到默认策略
}
// 检查是否在自定义列表中
for _, code := range codes {
if code == statusCode {
return true
}
}
return false
}
// IsInterceptWarmupEnabled 检查是否启用预热请求拦截
// 启用后标题生成、Warmup等预热请求将返回mock响应不消耗上游token
func (a *Account) IsInterceptWarmupEnabled() bool {
if a.Credentials == nil {
return false
}
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// =============== OpenAI 相关方法 ===============
// IsOpenAI 检查是否为 OpenAI 平台账号
func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI
}
// IsAnthropic 检查是否为 Anthropic 平台账号
func (a *Account) IsAnthropic() bool {
return a.Platform == PlatformAnthropic
}
// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号
func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth
}
// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号Response 账号)
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
}
// GetOpenAIBaseURL 获取 OpenAI API 基础 URL
// 对于 API Key 类型账号,从 credentials 中获取 base_url
// 对于 OAuth 类型账号,返回默认的 OpenAI API URL
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeApiKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL
}
}
return "https://api.openai.com" // OpenAI 默认 API URL
}
// GetOpenAIAccessToken 获取 OpenAI 访问令牌
func (a *Account) GetOpenAIAccessToken() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("access_token")
}
// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌
func (a *Account) GetOpenAIRefreshToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("refresh_token")
}
// GetOpenAIIDToken 获取 OpenAI ID TokenJWT包含用户信息
func (a *Account) GetOpenAIIDToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("id_token")
}
// GetOpenAIApiKey 获取 OpenAI API Key用于 Response 账号)
func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() {
return ""
}
return a.GetCredential("api_key")
}
// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent
// 返回空字符串表示透传原始 User-Agent
func (a *Account) GetOpenAIUserAgent() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("user_agent")
}
// GetChatGPTAccountID 获取 ChatGPT 账号 ID从 ID Token 解析)
func (a *Account) GetChatGPTAccountID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_account_id")
}
// GetChatGPTUserID 获取 ChatGPT 用户 ID从 ID Token 解析)
func (a *Account) GetChatGPTUserID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_user_id")
}
// GetOpenAIOrganizationID 获取 OpenAI 组织 ID
func (a *Account) GetOpenAIOrganizationID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("organization_id")
}
// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if !a.IsOpenAIOAuth() {
return nil
}
expiresAtStr := a.GetCredential("expires_at")
if expiresAtStr == "" {
return nil
}
// 尝试解析时间
t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil {
// 尝试解析为 Unix 时间戳
if v, ok := a.Credentials["expires_at"].(float64); ok {
t = time.Unix(int64(v), 0)
return &t
}
return nil
}
return &t
}
// IsOpenAITokenExpired 检查 OpenAI Token 是否过期
func (a *Account) IsOpenAITokenExpired() bool {
expiresAt := a.GetOpenAITokenExpiresAt()
if expiresAt == nil {
return false // 没有过期时间信息,假设未过期
}
// 提前 60 秒认为过期,便于刷新
return time.Now().Add(60 * time.Second).After(*expiresAt)
}

View File

@@ -1,20 +0,0 @@
package model
import (
"time"
)
type AccountGroup struct {
AccountID int64 `gorm:"primaryKey" json:"account_id"`
GroupID int64 `gorm:"primaryKey" json:"group_id"`
Priority int `gorm:"default:50;not null" json:"priority"` // 分组内优先级
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 关联
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (AccountGroup) TableName() string {
return "account_groups"
}

View File

@@ -1,32 +0,0 @@
package model
import (
"time"
"gorm.io/gorm"
)
type ApiKey struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
Key string `gorm:"uniqueIndex;size:128;not null" json:"key"` // sk-xxx
Name string `gorm:"size:100;not null" json:"name"`
GroupID *int64 `gorm:"index" json:"group_id"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (ApiKey) TableName() string {
return "api_keys"
}
// IsActive 检查是否激活
func (k *ApiKey) IsActive() bool {
return k.Status == "active"
}

View File

@@ -1,73 +0,0 @@
package model
import (
"time"
"gorm.io/gorm"
)
// 订阅类型常量
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
type Group struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"`
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
// 订阅功能字段
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
AccountGroups []AccountGroup `gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
AccountCount int64 `gorm:"-" json:"account_count,omitempty"`
}
func (Group) TableName() string {
return "groups"
}
// IsActive 检查是否激活
func (g *Group) IsActive() bool {
return g.Status == "active"
}
// IsSubscriptionType 检查是否为订阅类型分组
func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
// HasDailyLimit 检查是否有日限额
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
// HasWeeklyLimit 检查是否有周限额
func (g *Group) HasWeeklyLimit() bool {
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
}
// HasMonthlyLimit 检查是否有月限额
func (g *Group) HasMonthlyLimit() bool {
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
}

View File

@@ -1,64 +0,0 @@
package model
import (
"gorm.io/gorm"
)
// AutoMigrate 自动迁移所有模型
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&User{},
&ApiKey{},
&Group{},
&Account{},
&AccountGroup{},
&Proxy{},
&RedeemCode{},
&UsageLog{},
&Setting{},
&UserSubscription{},
)
}
// 状态常量
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// 角色常量
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// 平台常量
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
)
// 账号类型常量
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeApiKey = "apikey" // API Key类型账号
)
// 卡密类型常量
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
)
// 管理员调整类型常量
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)

View File

@@ -1,45 +0,0 @@
package model
import (
"fmt"
"time"
"gorm.io/gorm"
)
type Proxy struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Protocol string `gorm:"size:20;not null" json:"protocol"` // http/https/socks5
Host string `gorm:"size:255;not null" json:"host"`
Port int `gorm:"not null" json:"port"`
Username string `gorm:"size:100" json:"username"`
Password string `gorm:"size:100" json:"-"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
func (Proxy) TableName() string {
return "proxies"
}
// IsActive 检查是否激活
func (p *Proxy) IsActive() bool {
return p.Status == "active"
}
// URL 返回代理URL
func (p *Proxy) URL() string {
if p.Username != "" && p.Password != "" {
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
}
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
}
// ProxyWithAccountCount extends Proxy with account count information
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}

View File

@@ -1,50 +0,0 @@
package model
import (
"crypto/rand"
"encoding/hex"
"time"
)
type RedeemCode struct {
ID int64 `gorm:"primaryKey" json:"id"`
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
UsedBy *int64 `gorm:"index" json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 订阅类型专用字段
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
// 关联
User *User `gorm:"foreignKey:UsedBy" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (RedeemCode) TableName() string {
return "redeem_codes"
}
// IsUsed 检查是否已使用
func (r *RedeemCode) IsUsed() bool {
return r.Status == "used"
}
// CanUse 检查是否可以使用
func (r *RedeemCode) CanUse() bool {
return r.Status == "unused"
}
// GenerateRedeemCode 生成唯一的兑换码
func GenerateRedeemCode() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}

View File

@@ -1,104 +0,0 @@
package model
import (
"time"
)
// Setting 系统设置模型Key-Value存储
type Setting struct {
ID int64 `gorm:"primaryKey" json:"id"`
Key string `gorm:"uniqueIndex;size:100;not null" json:"key"`
Value string `gorm:"type:text;not null" json:"value"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
}
func (Setting) TableName() string {
return "settings"
}
// 设置Key常量
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
)
// 管理员 API Key 前缀(与用户 sk- 前缀区分)
const AdminApiKeyPrefix = "admin-"
// SystemSettings 系统设置结构体用于API响应
type SystemSettings struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"` // 不返回明文密码
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` // 不返回明文密钥
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
// PublicSettings 公开设置(无需登录即可获取)
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"`
}

View File

@@ -1,67 +0,0 @@
package model
import (
"time"
)
// 消费类型常量
const (
BillingTypeBalance int8 = 0 // 钱包余额
BillingTypeSubscription int8 = 1 // 订阅套餐
)
type UsageLog struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
ApiKeyID int64 `gorm:"index;not null" json:"api_key_id"`
AccountID int64 `gorm:"index;not null" json:"account_id"`
RequestID string `gorm:"size:64" json:"request_id"`
Model string `gorm:"size:100;index;not null" json:"model"`
// 订阅关联(可选)
GroupID *int64 `gorm:"index" json:"group_id"`
SubscriptionID *int64 `gorm:"index" json:"subscription_id"`
// Token使用量4类
InputTokens int `gorm:"default:0;not null" json:"input_tokens"`
OutputTokens int `gorm:"default:0;not null" json:"output_tokens"`
CacheCreationTokens int `gorm:"default:0;not null" json:"cache_creation_tokens"`
CacheReadTokens int `gorm:"default:0;not null" json:"cache_read_tokens"`
// 详细的缓存创建分类
CacheCreation5mTokens int `gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
// 费用USD
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
// 元数据
BillingType int8 `gorm:"type:smallint;default:0;not null" json:"billing_type"` // 0=余额 1=订阅
Stream bool `gorm:"default:false;not null" json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"` // 首字时间(流式请求)
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
ApiKey *ApiKey `gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
Subscription *UserSubscription `gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
}
func (UsageLog) TableName() string {
return "usage_logs"
}
// TotalTokens 总token数
func (u *UsageLog) TotalTokens() int {
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
}

View File

@@ -1,78 +0,0 @@
package model
import (
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type User struct {
ID int64 `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
Username string `gorm:"size:100;default:''" json:"username"`
Wechat string `gorm:"size:100;default:''" json:"wechat"`
Notes string `gorm:"type:text;default:''" json:"notes"`
PasswordHash string `gorm:"size:255;not null" json:"-"`
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
Concurrency int `gorm:"default:5;not null" json:"concurrency"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
AllowedGroups pq.Int64Array `gorm:"type:bigint[]" json:"allowed_groups"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
}
func (User) TableName() string {
return "users"
}
// IsAdmin 检查是否管理员
func (u *User) IsAdmin() bool {
return u.Role == "admin"
}
// IsActive 检查是否激活
func (u *User) IsActive() bool {
return u.Status == "active"
}
// CanBindGroup 检查是否可以绑定指定分组
// 对于标准类型分组:
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
return !isExclusive
}
// SetPassword 设置密码(哈希存储)
func (u *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.PasswordHash = string(hash)
return nil
}
// CheckPassword 验证密码
func (u *User) CheckPassword(password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
return err == nil
}

View File

@@ -1,157 +0,0 @@
package model
import (
"time"
)
// 订阅状态常量
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)
// UserSubscription 用户订阅模型
type UserSubscription struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
GroupID int64 `gorm:"index;not null" json:"group_id"`
// 订阅有效期
StartsAt time.Time `gorm:"not null" json:"starts_at"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/expired/suspended
// 滑动窗口起始时间nil = 未激活)
DailyWindowStart *time.Time `json:"daily_window_start"`
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
// 当前窗口已用额度USD基于 total_cost 计算)
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"daily_usage_usd"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"monthly_usage_usd"`
// 管理员分配信息
AssignedBy *int64 `gorm:"index" json:"assigned_by"`
AssignedAt time.Time `gorm:"not null" json:"assigned_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
AssignedByUser *User `gorm:"foreignKey:AssignedBy" json:"assigned_by_user,omitempty"`
}
func (UserSubscription) TableName() string {
return "user_subscriptions"
}
// IsActive 检查订阅是否有效状态为active且未过期
func (s *UserSubscription) IsActive() bool {
return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt)
}
// IsExpired 检查订阅是否已过期
func (s *UserSubscription) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
}
// DaysRemaining 返回订阅剩余天数
func (s *UserSubscription) DaysRemaining() int {
if s.IsExpired() {
return 0
}
return int(time.Until(s.ExpiresAt).Hours() / 24)
}
// IsWindowActivated 检查窗口是否已激活
func (s *UserSubscription) IsWindowActivated() bool {
return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil
}
// NeedsDailyReset 检查日窗口是否需要重置
func (s *UserSubscription) NeedsDailyReset() bool {
if s.DailyWindowStart == nil {
return false
}
return time.Since(*s.DailyWindowStart) >= 24*time.Hour
}
// NeedsWeeklyReset 检查周窗口是否需要重置
func (s *UserSubscription) NeedsWeeklyReset() bool {
if s.WeeklyWindowStart == nil {
return false
}
return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour
}
// NeedsMonthlyReset 检查月窗口是否需要重置
func (s *UserSubscription) NeedsMonthlyReset() bool {
if s.MonthlyWindowStart == nil {
return false
}
return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour
}
// DailyResetTime 返回日窗口重置时间
func (s *UserSubscription) DailyResetTime() *time.Time {
if s.DailyWindowStart == nil {
return nil
}
t := s.DailyWindowStart.Add(24 * time.Hour)
return &t
}
// WeeklyResetTime 返回周窗口重置时间
func (s *UserSubscription) WeeklyResetTime() *time.Time {
if s.WeeklyWindowStart == nil {
return nil
}
t := s.WeeklyWindowStart.Add(7 * 24 * time.Hour)
return &t
}
// MonthlyResetTime 返回月窗口重置时间
func (s *UserSubscription) MonthlyResetTime() *time.Time {
if s.MonthlyWindowStart == nil {
return nil
}
t := s.MonthlyWindowStart.Add(30 * 24 * time.Hour)
return &t
}
// CheckDailyLimit 检查是否超出日限额
func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool {
if !group.HasDailyLimit() {
return true // 无限制
}
return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD
}
// CheckWeeklyLimit 检查是否超出周限额
func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool {
if !group.HasWeeklyLimit() {
return true // 无限制
}
return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD
}
// CheckMonthlyLimit 检查是否超出月限额
func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool {
if !group.HasMonthlyLimit() {
return true // 无限制
}
return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD
}
// CheckAllLimits 检查所有限额
func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) {
daily = s.CheckDailyLimit(group, additionalCost)
weekly = s.CheckWeeklyLimit(group, additionalCost)
monthly = s.CheckMonthlyLimit(group, additionalCost)
return
}

View File

@@ -0,0 +1,42 @@
package gemini
// This package provides minimal fallback model metadata for Gemini native endpoints.
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
type Model struct {
Name string `json:"name"`
DisplayName string `json:"displayName,omitempty"`
Description string `json:"description,omitempty"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
}
type ModelsListResponse struct {
Models []Model `json:"models"`
}
func DefaultModels() []Model {
methods := []string{"generateContent", "streamGenerateContent"}
return []Model{
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash-lite", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
}
}
func FallbackModelsList() ModelsListResponse {
return ModelsListResponse{Models: DefaultModels()}
}
func FallbackModel(model string) Model {
methods := []string{"generateContent", "streamGenerateContent"}
if model == "" {
return Model{Name: "models/unknown", SupportedGenerationMethods: methods}
}
if len(model) >= 7 && model[:7] == "models/" {
return Model{Name: model, SupportedGenerationMethods: methods}
}
return Model{Name: "models/" + model, SupportedGenerationMethods: methods}
}

View File

@@ -0,0 +1,38 @@
package geminicli
// LoadCodeAssistRequest matches done-hub's internal Code Assist call.
type LoadCodeAssistRequest struct {
Metadata LoadCodeAssistMetadata `json:"metadata"`
}
type LoadCodeAssistMetadata struct {
IDEType string `json:"ideType"`
Platform string `json:"platform"`
PluginType string `json:"pluginType"`
}
type LoadCodeAssistResponse struct {
CurrentTier string `json:"currentTier,omitempty"`
CloudAICompanionProject string `json:"cloudaicompanionProject,omitempty"`
AllowedTiers []AllowedTier `json:"allowedTiers,omitempty"`
}
type AllowedTier struct {
ID string `json:"id"`
IsDefault bool `json:"isDefault,omitempty"`
}
type OnboardUserRequest struct {
TierID string `json:"tierId"`
Metadata LoadCodeAssistMetadata `json:"metadata"`
}
type OnboardUserResponse struct {
Done bool `json:"done"`
Response *OnboardUserResultData `json:"response,omitempty"`
Name string `json:"name,omitempty"`
}
type OnboardUserResultData struct {
CloudAICompanionProject any `json:"cloudaicompanionProject,omitempty"`
}

View File

@@ -0,0 +1,42 @@
package geminicli
import "time"
const (
AIStudioBaseURL = "https://generativelanguage.googleapis.com"
GeminiCliBaseURL = "https://cloudcode-pa.googleapis.com"
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
TokenURL = "https://oauth2.googleapis.com/token"
// AIStudioOAuthRedirectURI is the default redirect URI used for AI Studio OAuth.
// This matches the "copy/paste callback URL" flow used by OpenAI OAuth in this project.
// Note: You still need to register this redirect URI in your Google OAuth client
// unless you use an OAuth client type that permits localhost redirect URIs.
AIStudioOAuthRedirectURI = "http://localhost:1455/auth/callback"
// DefaultScopes for Code Assist (includes cloud-platform for API access plus userinfo scopes)
// Required by Google's Code Assist API.
DefaultCodeAssistScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
// DefaultScopes for AI Studio (uses generativelanguage API with OAuth)
// Reference: https://ai.google.dev/gemini-api/docs/oauth
// For regular Google accounts, supports API calls to generativelanguage.googleapis.com
// Note: Google Auth platform currently documents the OAuth scope as
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
GeminiCLIRedirectURI = "https://codeassist.google.com/authcode"
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
// They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret = "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"
SessionTTL = 30 * time.Minute
// GeminiCLIUserAgent mimics Gemini CLI to maximize compatibility with internal endpoints.
GeminiCLIUserAgent = "GeminiCLI/0.1.5 (Windows; AMD64)"
)

View File

@@ -0,0 +1,21 @@
package geminicli
// Model represents a selectable Gemini model for UI/testing purposes.
// Keep JSON fields consistent with existing frontend expectations.
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var DefaultModels = []Model{
{ID: "gemini-3-pro", Type: "model", DisplayName: "Gemini 3 Pro", CreatedAt: ""},
{ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.
const DefaultTestModel = "gemini-2.5-pro"

View File

@@ -0,0 +1,243 @@
package geminicli
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
type OAuthConfig struct {
ClientID string
ClientSecret string
Scopes string
}
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
CreatedAt time.Time `json:"created_at"`
}
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
go store.cleanup()
return store
}
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
func (s *SessionStore) Stop() {
select {
case <-s.stopCh:
return
default:
close(s.stopCh)
}
}
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier returns an RFC 7636 compatible code verifier (43+ chars).
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
func base64URLEncode(data []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
}
// EffectiveOAuthConfig returns the effective OAuth configuration.
// oauthType: "code_assist" or "ai_studio" (defaults to "code_assist" if empty).
//
// If ClientID/ClientSecret is not provided, this falls back to the built-in Gemini CLI OAuth client.
//
// Note: The built-in Gemini CLI OAuth client is restricted and may reject some scopes (e.g.
// https://www.googleapis.com/auth/generative-language), which will surface as
// "restricted_client" / "Unregistered scope(s)" errors during browser authorization.
func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error) {
effective := OAuthConfig{
ClientID: strings.TrimSpace(cfg.ClientID),
ClientSecret: strings.TrimSpace(cfg.ClientSecret),
Scopes: strings.TrimSpace(cfg.Scopes),
}
// Normalize scopes: allow comma-separated input but send space-delimited scopes to Google.
if effective.Scopes != "" {
effective.Scopes = strings.Join(strings.Fields(strings.ReplaceAll(effective.Scopes, ",", " ")), " ")
}
// Fall back to built-in Gemini CLI OAuth client when not configured.
if effective.ClientID == "" && effective.ClientSecret == "" {
effective.ClientID = GeminiCLIOAuthClientID
effective.ClientSecret = GeminiCLIOAuthClientSecret
} else if effective.ClientID == "" || effective.ClientSecret == "" {
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
}
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID &&
effective.ClientSecret == GeminiCLIOAuthClientSecret
if effective.Scopes == "" {
// Use different default scopes based on OAuth type
if oauthType == "ai_studio" {
// Built-in client can't request some AI Studio scopes (notably generative-language).
if isBuiltinClient {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = DefaultAIStudioScopes
}
} else {
// Default to Code Assist scopes
effective.Scopes = DefaultCodeAssistScopes
}
} else if oauthType == "ai_studio" && isBuiltinClient {
// If user overrides scopes while still using the built-in client, strip restricted scopes.
parts := strings.Fields(effective.Scopes)
filtered := make([]string, 0, len(parts))
for _, s := range parts {
if strings.Contains(s, "generative-language") {
continue
}
filtered = append(filtered, s)
}
if len(filtered) == 0 {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = strings.Join(filtered, " ")
}
}
// Backward compatibility: normalize older AI Studio scope to the currently documented one.
if oauthType == "ai_studio" && effective.Scopes != "" {
parts := strings.Fields(effective.Scopes)
for i := range parts {
if parts[i] == "https://www.googleapis.com/auth/generative-language" {
parts[i] = "https://www.googleapis.com/auth/generative-language.retriever"
}
}
effective.Scopes = strings.Join(parts, " ")
}
return effective, nil
}
func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI, projectID, oauthType string) (string, error) {
effectiveCfg, err := EffectiveOAuthConfig(cfg, oauthType)
if err != nil {
return "", err
}
redirectURI = strings.TrimSpace(redirectURI)
if redirectURI == "" {
return "", fmt.Errorf("redirect_uri is required")
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", effectiveCfg.ClientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", effectiveCfg.Scopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
params.Set("access_type", "offline")
params.Set("prompt", "consent")
params.Set("include_granted_scopes", "true")
if strings.TrimSpace(projectID) != "" {
params.Set("project_id", strings.TrimSpace(projectID))
}
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()), nil
}

View File

@@ -0,0 +1,46 @@
package geminicli
import "strings"
const maxLogBodyLen = 2048
func SanitizeBodyForLogs(body string) string {
body = truncateBase64InMessage(body)
if len(body) > maxLogBodyLen {
body = body[:maxLogBodyLen] + "...[truncated]"
}
return body
}
func truncateBase64InMessage(message string) string {
const maxBase64Length = 50
result := message
offset := 0
for {
idx := strings.Index(result[offset:], ";base64,")
if idx == -1 {
break
}
actualIdx := offset + idx
start := actualIdx + len(";base64,")
end := start
for end < len(result) && isBase64Char(result[end]) {
end++
}
if end-start > maxBase64Length {
result = result[:start+maxBase64Length] + "...[truncated]" + result[end:]
offset = start + maxBase64Length + len("...[truncated]")
continue
}
offset = end
}
return result
}
func isBase64Char(c byte) bool {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '+' || c == '/' || c == '='
}

View File

@@ -0,0 +1,9 @@
package geminicli
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
Scope string `json:"scope,omitempty"`
}

View File

@@ -0,0 +1,24 @@
package googleapi
import "net/http"
// HTTPStatusToGoogleStatus maps HTTP status codes to Google-style error status strings.
func HTTPStatusToGoogleStatus(status int) string {
switch status {
case http.StatusBadRequest:
return "INVALID_ARGUMENT"
case http.StatusUnauthorized:
return "UNAUTHENTICATED"
case http.StatusForbidden:
return "PERMISSION_DENIED"
case http.StatusNotFound:
return "NOT_FOUND"
case http.StatusTooManyRequests:
return "RESOURCE_EXHAUSTED"
default:
if status >= 500 {
return "INTERNAL"
}
return "UNKNOWN"
}
}

View File

@@ -4,14 +4,17 @@ import (
"math"
"net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin"
)
// Response 标准API响应格式
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data,omitempty"`
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Data any `json:"data,omitempty"`
}
// PaginatedData 分页数据格式(匹配前端期望)
@@ -44,11 +47,36 @@ func Created(c *gin.Context, data any) {
// Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Code: statusCode,
Message: message,
Reason: "",
Metadata: nil,
})
}
// ErrorWithDetails returns an error response compatible with the existing envelope while
// optionally providing structured error fields (reason/metadata).
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Reason: reason,
Metadata: metadata,
})
}
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
// It returns true if an error was written.
func ErrorFrom(c *gin.Context, err error) bool {
if err == nil {
return false
}
statusCode, status := infraerrors.ToHTTP(err)
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}
// BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message)

View File

@@ -0,0 +1,171 @@
//go:build unit
package response
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestErrorWithDetails(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
message string
reason string
metadata map[string]string
want Response
}{
{
name: "plain_error",
statusCode: http.StatusBadRequest,
message: "invalid request",
want: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
},
},
{
name: "structured_error",
statusCode: http.StatusForbidden,
message: "no access",
reason: "FORBIDDEN",
metadata: map[string]string{"k": "v"},
want: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"k": "v"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
require.Equal(t, tt.statusCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.want, got)
})
}
}
func TestErrorFrom(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
err error
wantWritten bool
wantHTTPCode int
wantBody Response
}{
{
name: "nil_error",
err: nil,
wantWritten: false,
},
{
name: "application_error",
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
wantWritten: true,
wantHTTPCode: http.StatusForbidden,
wantBody: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"scope": "admin"},
},
},
{
name: "bad_request_error",
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"),
wantWritten: true,
wantHTTPCode: http.StatusBadRequest,
wantBody: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
Reason: "INVALID_REQUEST",
},
},
{
name: "unauthorized_error",
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"),
wantWritten: true,
wantHTTPCode: http.StatusUnauthorized,
wantBody: Response{
Code: http.StatusUnauthorized,
Message: "unauthorized",
Reason: "UNAUTHORIZED",
},
},
{
name: "not_found_error",
err: infraerrors.NotFound("NOT_FOUND", "not found"),
wantWritten: true,
wantHTTPCode: http.StatusNotFound,
wantBody: Response{
Code: http.StatusNotFound,
Message: "not found",
Reason: "NOT_FOUND",
},
},
{
name: "conflict_error",
err: infraerrors.Conflict("CONFLICT", "conflict"),
wantWritten: true,
wantHTTPCode: http.StatusConflict,
wantBody: Response{
Code: http.StatusConflict,
Message: "conflict",
Reason: "CONFLICT",
},
},
{
name: "unknown_error_defaults_to_500",
err: errors.New("boom"),
wantWritten: true,
wantHTTPCode: http.StatusInternalServerError,
wantBody: Response{
Code: http.StatusInternalServerError,
Message: infraerrors.UnknownMessage,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
written := ErrorFrom(c, tt.err)
require.Equal(t, tt.wantWritten, written)
if !tt.wantWritten {
require.Equal(t, 200, w.Code)
require.Empty(t, w.Body.String())
return
}
require.Equal(t, tt.wantHTTPCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.wantBody, got)
})
}
}

View File

@@ -127,10 +127,15 @@ type UserDashboardStats struct {
// UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct {
UserID int64
ApiKeyID int64
StartTime *time.Time
EndTime *time.Time
UserID int64
ApiKeyID int64
AccountID int64
GroupID int64
Model string
Stream *bool
BillingType *int8
StartTime *time.Time
EndTime *time.Time
}
// UsageStats represents usage statistics

View File

@@ -2,68 +2,85 @@ package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"errors"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type AccountRepository struct {
type accountRepository struct {
db *gorm.DB
}
func NewAccountRepository(db *gorm.DB) *AccountRepository {
return &AccountRepository{db: db}
func NewAccountRepository(db *gorm.DB) service.AccountRepository {
return &accountRepository{db: db}
}
func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Create(account).Error
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
m := accountModelFromService(account)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyAccountModelToService(account, m)
}
return err
}
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) {
var m accountModel
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
return accountModelToService(&m), nil
}
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
if crsAccountID == "" {
return nil, nil
}
var m accountModel
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&m).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
// 填充 GroupIDs 和 Groups 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
for _, ag := range account.AccountGroups {
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
return accountModelToService(&m), nil
}
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
m := accountModelFromService(account)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyAccountModelToService(account, m)
}
return &account, nil
return err
}
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Save(account).Error
}
func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
// 先删除账号与分组的绑定关系
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&accountGroupModel{}).Error; err != nil {
return err
}
// 再删除账号
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
return r.db.WithContext(ctx).Delete(&accountModel{}, id).Error
}
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "")
}
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
var accounts []model.Account
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
var accounts []accountModel
var total int64
db := r.db.WithContext(ctx).Model(&model.Account{})
db := r.db.WithContext(ctx).Model(&accountModel{})
// Apply filters
if platform != "" {
db = db.Where("platform = ?", platform)
}
@@ -86,67 +103,84 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati
return nil, nil, err
}
// 填充每个 Account 的虚拟字段GroupIDs 和 Groups
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups))
for _, ag := range accounts[i].AccountGroups {
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
if ag.Group != nil {
accounts[i].Groups = append(accounts[i].Groups, ag.Group)
}
}
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return accounts, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outAccounts, paginationResultFromTotal(total, params), nil
}
func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, model.StatusActive).
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, service.StatusActive).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive).
Where("status = ?", service.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, service.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Update("last_used_at", now).Error
}
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"status": model.StatusError,
"status": service.StatusError,
"error_message": errorMsg,
}).Error
}
func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &model.AccountGroup{
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
@@ -154,133 +188,150 @@ func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID i
return r.db.WithContext(ctx).Create(ag).Error
}
func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&model.AccountGroup{}).Error
Delete(&accountGroupModel{}).Error
}
func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
var groups []model.Group
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id").
Where("account_groups.account_id = ?", accountID).
Find(&groups).Error
return groups, err
if err != nil {
return nil, err
}
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
return outGroups, nil
}
func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, model.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
}
func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
// 删除现有绑定
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&accountGroupModel{}).Error; err != nil {
return err
}
// 添加新绑定
if len(groupIDs) > 0 {
accountGroups := make([]model.AccountGroup, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, model.AccountGroup{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1, // 使用索引作为优先级
})
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
if len(groupIDs) == 0 {
return nil
}
return nil
accountGroups := make([]accountGroupModel, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1,
})
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
}
// ListSchedulable 获取所有可调度的账号
func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// ListSchedulableByGroupID 按组获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// ListSchedulableByPlatform 按平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Where("platform = ?", platform).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.platform = ?", platform).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// SetRateLimited 标记账号为限流状态(429)
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"rate_limited_at": now,
"rate_limit_reset_at": resetAt,
}).Error
}
// SetOverloaded 标记账号为过载状态(529)
func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("overload_until", until).Error
}
// ClearRateLimit 清除账号的限流状态
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"rate_limited_at": nil,
"rate_limit_reset_at": nil,
@@ -288,8 +339,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
}).Error
}
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]any{
"session_window_status": status,
}
@@ -299,39 +349,241 @@ func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
if end != nil {
updates["session_window_end"] = end
}
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Updates(updates).Error
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Updates(updates).Error
}
// SetSchedulable 设置账号的调度开关
func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("schedulable", schedulable).Error
}
// UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 {
return nil
}
// Get current account to preserve existing Extra data
var account model.Account
var account accountModel
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
return err
}
// Initialize Extra if nil
if account.Extra == nil {
account.Extra = make(model.JSONB)
account.Extra = datatypes.JSONMap{}
}
// Merge updates into existing Extra
for k, v := range updates {
account.Extra[k] = v
}
// Save updated Extra
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("extra", account.Extra).Error
}
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
updateMap := map[string]any{}
if updates.Name != nil {
updateMap["name"] = *updates.Name
}
if updates.ProxyID != nil {
updateMap["proxy_id"] = updates.ProxyID
}
if updates.Concurrency != nil {
updateMap["concurrency"] = *updates.Concurrency
}
if updates.Priority != nil {
updateMap["priority"] = *updates.Priority
}
if updates.Status != nil {
updateMap["status"] = *updates.Status
}
if len(updates.Credentials) > 0 {
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", datatypes.JSONMap(updates.Credentials))
}
if len(updates.Extra) > 0 {
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", datatypes.JSONMap(updates.Extra))
}
if len(updateMap) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).
Model(&accountModel{}).
Where("id IN ?", ids).
Clauses(clause.Returning{}).
Updates(updateMap)
return result.RowsAffected, result.Error
}
type accountModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Platform string `gorm:"size:50;not null"`
Type string `gorm:"size:20;not null"`
Credentials datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
Extra datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
ProxyID *int64 `gorm:"index"`
Concurrency int `gorm:"default:3;not null"`
Priority int `gorm:"default:50;not null"`
Status string `gorm:"size:20;default:active;not null"`
ErrorMessage string `gorm:"type:text"`
LastUsedAt *time.Time `gorm:"index"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
Schedulable bool `gorm:"default:true;not null"`
RateLimitedAt *time.Time `gorm:"index"`
RateLimitResetAt *time.Time `gorm:"index"`
OverloadUntil *time.Time `gorm:"index"`
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string `gorm:"size:20"`
Proxy *proxyModel `gorm:"foreignKey:ProxyID"`
AccountGroups []accountGroupModel `gorm:"foreignKey:AccountID"`
}
func (accountModel) TableName() string { return "accounts" }
type accountGroupModel struct {
AccountID int64 `gorm:"primaryKey"`
GroupID int64 `gorm:"primaryKey"`
Priority int `gorm:"default:50;not null"`
CreatedAt time.Time `gorm:"not null"`
Account *accountModel `gorm:"foreignKey:AccountID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (accountGroupModel) TableName() string { return "account_groups" }
func accountGroupModelToService(m *accountGroupModel) *service.AccountGroup {
if m == nil {
return nil
}
return &service.AccountGroup{
AccountID: m.AccountID,
GroupID: m.GroupID,
Priority: m.Priority,
CreatedAt: m.CreatedAt,
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
}
}
func accountModelToService(m *accountModel) *service.Account {
if m == nil {
return nil
}
var credentials map[string]any
if m.Credentials != nil {
credentials = map[string]any(m.Credentials)
}
var extra map[string]any
if m.Extra != nil {
extra = map[string]any(m.Extra)
}
account := &service.Account{
ID: m.ID,
Name: m.Name,
Platform: m.Platform,
Type: m.Type,
Credentials: credentials,
Extra: extra,
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
Status: m.Status,
ErrorMessage: m.ErrorMessage,
LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: m.SessionWindowStatus,
Proxy: proxyModelToService(m.Proxy),
}
if len(m.AccountGroups) > 0 {
account.AccountGroups = make([]service.AccountGroup, 0, len(m.AccountGroups))
account.GroupIDs = make([]int64, 0, len(m.AccountGroups))
account.Groups = make([]*service.Group, 0, len(m.AccountGroups))
for i := range m.AccountGroups {
ag := accountGroupModelToService(&m.AccountGroups[i])
if ag == nil {
continue
}
account.AccountGroups = append(account.AccountGroups, *ag)
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
}
return account
}
func accountModelFromService(a *service.Account) *accountModel {
if a == nil {
return nil
}
var credentials datatypes.JSONMap
if a.Credentials != nil {
credentials = datatypes.JSONMap(a.Credentials)
}
var extra datatypes.JSONMap
if a.Extra != nil {
extra = datatypes.JSONMap(a.Extra)
}
return &accountModel{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: credentials,
Extra: extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
}
}
func applyAccountModelToService(account *service.Account, m *accountModel) {
if account == nil || m == nil {
return
}
account.ID = m.ID
account.CreatedAt = m.CreatedAt
account.UpdatedAt = m.UpdatedAt
}

View File

@@ -0,0 +1,585 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
)
type AccountRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *accountRepository
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewAccountRepository(s.db).(*accountRepository)
}
func TestAccountRepoSuite(t *testing.T) {
suite.Run(t, new(AccountRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *AccountRepoSuite) TestCreate() {
account := &service.Account{
Name: "test-create",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{},
Extra: map[string]any{},
Concurrency: 3,
Priority: 50,
Schedulable: true,
}
err := s.repo.Create(s.ctx, account)
s.Require().NoError(err, "Create")
s.Require().NotZero(account.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *AccountRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *AccountRepoSuite) TestUpdate() {
account := accountModelToService(mustCreateAccount(s.T(), s.db, &accountModel{Name: "original"}))
account.Name = "updated"
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, account.ID)
s.Require().Error(err, "expected error after delete")
}
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete should cascade remove bindings")
var count int64
s.db.Model(&accountGroupModel{}).Where("account_id = ?", account.ID).Count(&count)
s.Require().Zero(count, "expected bindings to be removed")
}
// --- List / ListWithFilters ---
func (s *AccountRepoSuite) TestList() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc2"})
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(accounts, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *AccountRepoSuite) TestListWithFilters() {
tests := []struct {
name string
setup func(db *gorm.DB)
platform string
accType string
status string
search string
wantCount int
validate func(accounts []service.Account)
}{
{
name: "filter_by_platform",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic})
mustCreateAccount(s.T(), db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI})
},
platform: service.PlatformOpenAI,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.PlatformOpenAI, accounts[0].Platform)
},
},
{
name: "filter_by_type",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &accountModel{Name: "t2", Type: service.AccountTypeApiKey})
},
accType: service.AccountTypeApiKey,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
},
},
{
name: "filter_by_status",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "s1", Status: service.StatusActive})
mustCreateAccount(s.T(), db, &accountModel{Name: "s2", Status: service.StatusDisabled})
},
status: service.StatusDisabled,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.StatusDisabled, accounts[0].Status)
},
},
{
name: "filter_by_search",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &accountModel{Name: "beta-account"})
},
search: "alpha",
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Contains(accounts[0].Name, "alpha")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
db := testTx(s.T())
repo := NewAccountRepository(db).(*accountRepository)
ctx := context.Background()
tt.setup(db)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil {
tt.validate(accounts)
}
})
}
}
// --- ListByGroup / ListActive / ListByPlatform ---
func (s *AccountRepoSuite) TestListByGroup() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Status: service.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Status: service.StatusActive})
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
s.Require().NoError(err, "ListByGroup")
s.Require().Len(accounts, 2)
// Should be ordered by priority
s.Require().Equal(acc2.ID, accounts[0].ID, "expected acc2 first (priority=1)")
}
func (s *AccountRepoSuite) TestListActive() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "active1", Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "inactive1", Status: service.StatusDisabled})
accounts, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(accounts, 1)
s.Require().Equal("active1", accounts[0].Name)
}
func (s *AccountRepoSuite) TestListByPlatform() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListByPlatform")
s.Require().Len(accounts, 1)
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
}
// --- Preload and VirtualFields ---
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
account := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc1",
ProxyID: &proxy.ID,
})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().NotNil(got.Proxy, "expected Proxy preload")
s.Require().Equal(proxy.ID, got.Proxy.ID)
s.Require().Len(got.GroupIDs, 1, "expected GroupIDs to be populated")
s.Require().Equal(group.ID, got.GroupIDs[0])
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
s.Require().Equal(group.ID, got.Groups[0].ID)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(accounts, 1)
s.Require().NotNil(accounts[0].Proxy, "expected Proxy preload in list")
s.Require().Equal(proxy.ID, accounts[0].Proxy.ID)
s.Require().Len(accounts[0].GroupIDs, 1, "expected GroupIDs in list")
s.Require().Equal(group.ID, accounts[0].GroupIDs[0])
}
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc"})
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups")
s.Require().Len(groups, 1, "expected 1 group")
s.Require().Equal(g1.ID, groups[0].ID)
s.Require().NoError(s.repo.RemoveFromGroup(s.ctx, account.ID, g1.ID), "RemoveFromGroup")
groups, err = s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups after remove")
s.Require().Empty(groups, "expected 0 groups after remove")
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{g1.ID, g2.ID}), "BindGroups")
groups, err = s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups after bind")
s.Require().Len(groups, 2, "expected 2 groups after bind")
}
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Empty(groups, "expected 0 groups after binding empty list")
}
// --- Schedulable ---
func (s *AccountRepoSuite) TestListSchedulable() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
sched, err := s.repo.ListSchedulable(s.ctx)
s.Require().NoError(err, "ListSchedulable")
ids := idsOfAccounts(sched)
s.Require().Contains(ids, okAcc.ID)
s.Require().NotContains(ids, overloaded.ID)
}
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.db, &accountModel{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
sched, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ListSchedulableByGroupID")
s.Require().Len(sched, 1, "expected only ok account schedulable")
s.Require().Equal(okAcc.ID, sched[0].ID)
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, rateLimited.ID), "ClearRateLimit")
sched2, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ListSchedulableByGroupID after ClearRateLimit")
s.Require().Len(sched2, 2, "expected 2 schedulable accounts after ClearRateLimit")
}
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
}
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(a1.ID, accounts[0].ID)
}
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-sched", Schedulable: true})
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().False(got.Schedulable)
}
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func (s *AccountRepoSuite) TestSetOverloaded() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-over"})
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.OverloadUntil)
s.Require().WithinDuration(until, *got.OverloadUntil, time.Second)
}
func (s *AccountRepoSuite) TestSetRateLimited() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-rl"})
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.RateLimitedAt)
s.Require().NotNil(got.RateLimitResetAt)
s.Require().WithinDuration(resetAt, *got.RateLimitResetAt, time.Second)
}
func (s *AccountRepoSuite) TestClearRateLimit() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-clear"})
until := time.Now().Add(1 * time.Hour)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, account.ID))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Nil(got.RateLimitedAt)
s.Require().Nil(got.RateLimitResetAt)
s.Require().Nil(got.OverloadUntil)
}
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-used"})
s.Require().Nil(account.LastUsedAt)
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.LastUsedAt)
}
// --- SetError ---
func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-err", Status: service.StatusActive})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal(service.StatusError, got.Status)
s.Require().Equal("something went wrong", got.ErrorMessage)
}
// --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-win"})
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.UpdateSessionWindow(s.ctx, account.ID, &start, &end, "active"))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.SessionWindowStart)
s.Require().NotNil(got.SessionWindowEnd)
s.Require().Equal("active", got.SessionWindowStatus)
}
// --- UpdateExtra ---
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
account := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc-extra",
Extra: datatypes.JSONMap{"a": "1"},
})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("1", got.Extra["a"])
s.Require().Equal("2", got.Extra["b"])
}
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-extra-empty"})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
}
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-nil-extra", Extra: nil})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal("val", got.Extra["key"])
}
// --- GetByCRSAccountID ---
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
crsID := "crs-12345"
mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc-crs",
Extra: datatypes.JSONMap{"crs_account_id": crsID},
})
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
s.Require().NoError(err)
s.Require().NotNil(got)
s.Require().Equal("acc-crs", got.Name)
}
func (s *AccountRepoSuite) TestGetByCRSAccountID_NotFound() {
got, err := s.repo.GetByCRSAccountID(s.ctx, "non-existent")
s.Require().NoError(err)
s.Require().Nil(got)
}
func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
got, err := s.repo.GetByCRSAccountID(s.ctx, "")
s.Require().NoError(err)
s.Require().Nil(got)
}
// --- BulkUpdate ---
func (s *AccountRepoSuite) TestBulkUpdate() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk2", Priority: 1})
newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
Priority: &newPriority,
})
s.Require().NoError(err)
s.Require().GreaterOrEqual(affected, int64(1), "expected at least one affected row")
got1, _ := s.repo.GetByID(s.ctx, a1.ID)
got2, _ := s.repo.GetByID(s.ctx, a2.ID)
s.Require().Equal(99, got1.Priority)
s.Require().Equal(99, got2.Priority)
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bulk-cred",
Credentials: datatypes.JSONMap{"existing": "value"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Credentials: datatypes.JSONMap{"new_key": "new_value"},
})
s.Require().NoError(err)
got, _ := s.repo.GetByID(s.ctx, a1.ID)
s.Require().Equal("value", got.Credentials["existing"])
s.Require().Equal("new_value", got.Credentials["new_key"])
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bulk-extra",
Extra: datatypes.JSONMap{"existing": "val"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Extra: datatypes.JSONMap{"new_key": "new_val"},
})
s.Require().NoError(err)
got, _ := s.repo.GetByID(s.ctx, a1.ID)
s.Require().Equal("val", got.Extra["existing"])
s.Require().Equal("new_val", got.Extra["new_key"])
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, service.AccountBulkUpdate{})
s.Require().NoError(err)
s.Require().Zero(affected)
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
s.Require().NoError(err)
s.Require().Zero(affected)
}
func idsOfAccounts(accounts []service.Account) []int64 {
out := make([]int64, 0, len(accounts))
for i := range accounts {
out = append(out, accounts[i].ID)
}
return out
}

View File

@@ -2,11 +2,11 @@ package repository
import (
"context"
"errors"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
@@ -15,21 +15,30 @@ const (
apiKeyRateLimitDuration = 24 * time.Hour
)
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
func apiKeyRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
}
type apiKeyCache struct {
rdb *redis.Client
}
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache {
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
return &apiKeyCache{rdb: rdb}
}
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
key := apiKeyRateLimitKey(userID)
count, err := c.rdb.Get(ctx, key).Int()
if errors.Is(err, redis.Nil) {
return 0, nil
}
return count, err
}
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
key := apiKeyRateLimitKey(userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
@@ -38,7 +47,7 @@ func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID in
}
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
key := apiKeyRateLimitKey(userID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,127 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ApiKeyCacheSuite struct {
IntegrationRedisSuite
}
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{
{
name: "missing_key_returns_zero_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.NoError(s.T(), err, "expected nil error for missing key")
require.Equal(s.T(), 0, count, "expected zero count for missing key")
},
},
{
name: "increment_increases_count_and_sets_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount")
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2")
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.NoError(s.T(), err, "GetCreateAttemptCount")
require.Equal(s.T(), 2, count, "count mismatch")
ttl, err := rdb.TTL(ctx, key).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration)
},
},
{
name: "delete_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.NoError(s.T(), err, "expected nil error after delete")
require.Equal(s.T(), 0, count, "expected zero count after delete")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
rdb := testRedis(s.T())
cache := &apiKeyCache{rdb: rdb}
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func (s *ApiKeyCacheSuite) TestDailyUsage() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{
{
name: "increment_increases_count",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
dailyKey := "daily:sk-test"
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage")
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2")
n, err := rdb.Get(ctx, dailyKey).Int()
require.NoError(s.T(), err, "Get dailyKey")
require.Equal(s.T(), 2, n, "expected daily usage=2")
},
},
{
name: "set_expiry_sets_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
dailyKey := "daily:sk-test-expiry"
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey))
require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry")
ttl, err := rdb.TTL(ctx, dailyKey).Result()
require.NoError(s.T(), err, "TTL dailyKey")
require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := &apiKeyCache{rdb: rdb}
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func TestApiKeyCacheSuite(t *testing.T) {
suite.Run(t, new(ApiKeyCacheSuite))
}

View File

@@ -0,0 +1,46 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestApiKeyRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "apikey:ratelimit:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "apikey:ratelimit:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "apikey:ratelimit:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "apikey:ratelimit:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := apiKeyRateLimitKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -2,55 +2,68 @@ package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/model"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type ApiKeyRepository struct {
type apiKeyRepository struct {
db *gorm.DB
}
func NewApiKeyRepository(db *gorm.DB) *ApiKeyRepository {
return &ApiKeyRepository{db: db}
func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &apiKeyRepository{db: db}
}
func (r *ApiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Create(key).Error
}
func (r *ApiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
var key model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
if err != nil {
return nil, err
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return &key, nil
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
}
func (r *ApiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
var apiKey model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
}
return &apiKey, nil
return apiKeyModelToService(&m), nil
}
func (r *ApiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
}
return apiKeyModelToService(&m), nil
}
func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return err
}
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel
var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID)
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -60,36 +73,47 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return keys, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outKeys, paginationResultFromTotal(total, params), nil
}
func (r *ApiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
ids := make([]int64, 0, len(apiKeyIDs))
err := r.db.WithContext(ctx).
Model(&apiKeyModel{}).
Where("user_id = ? AND id IN ?", userID, apiKeyIDs).
Pluck("id", &ids).Error
if err != nil {
return nil, err
}
return ids, nil
}
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
}
func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
return count > 0, err
}
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel
var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID)
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -99,24 +123,19 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return keys, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outKeys, paginationResultFromTotal(total, params), nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
var keys []model.ApiKey
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
var keys []apiKeyModel
db := r.db.WithContext(ctx).Model(&model.ApiKey{})
db := r.db.WithContext(ctx).Model(&apiKeyModel{})
if userID > 0 {
db = db.Where("user_id = ?", userID)
@@ -131,20 +150,84 @@ func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return nil, err
}
return keys, nil
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return outKeys, nil
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("group_id = ?", groupID).
Update("group_id", nil)
return result.RowsAffected, result.Error
}
// CountByGroupID 获取分组的 API Key 数量
func (r *ApiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err
}
type apiKeyModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
Key string `gorm:"uniqueIndex;size:128;not null"`
Name string `gorm:"size:100;not null"`
GroupID *int64 `gorm:"index"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (apiKeyModel) TableName() string { return "api_keys" }
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
if m == nil {
return nil
}
return &service.ApiKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
GroupID: m.GroupID,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
}
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel {
if k == nil {
return nil
}
return &apiKeyModel{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
}
}
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) {
if key == nil || m == nil {
return
}
key.ID = m.ID
key.CreatedAt = m.CreatedAt
key.UpdatedAt = m.UpdatedAt
}

View File

@@ -0,0 +1,355 @@
//go:build integration
package repository
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type ApiKeyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *apiKeyRepository
}
func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository)
}
func TestApiKeyRepoSuite(t *testing.T) {
suite.Run(t, new(ApiKeyRepoSuite))
}
// --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-create-test",
Name: "Test Key",
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, key)
s.Require().NoError(err, "Create")
s.Require().NotZero(key.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("sk-create-test", got.Key)
}
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *ApiKeyRepoSuite) TestGetByKey() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"})
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
GroupID: &group.ID,
Status: service.StatusActive,
})
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
s.Require().Equal(key.ID, got.ID)
s.Require().NotNil(got.User, "expected User preload")
s.Require().Equal(user.ID, got.User.ID)
s.Require().NotNil(got.Group, "expected Group preload")
s.Require().Equal(group.ID, got.Group.ID)
}
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
s.Require().Error(err, "expected error for non-existent key")
}
// --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
Status: service.StatusActive,
}))
key.Name = "Renamed"
key.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("sk-update", got.Key, "Update should not change key")
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got.Name)
s.Require().Equal(service.StatusDisabled, got.Status)
}
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
GroupID: &group.ID,
}))
key.GroupID = nil
err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err)
s.Require().Nil(got.GroupID, "expected GroupID to be cleared")
}
// --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
})
err := s.repo.Delete(s.ctx, key.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, key.ID)
s.Require().Error(err, "expected error after delete")
}
// --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"})
for i := 0; i < 5; i++ {
mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-page-" + string(rune('a'+i)),
Name: "Key",
})
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
s.Require().NoError(err)
s.Require().Len(keys, 2)
s.Require().Equal(int64(5), page.Total)
s.Require().Equal(3, page.Pages)
}
func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
s.Require().Equal(int64(2), count)
}
// --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
// User preloaded
s.Require().NotNil(keys[0].User)
}
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(1), count)
}
// --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"})
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists)
notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- SearchApiKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
s.Require().Len(found, 1)
s.Require().Contains(found[0].Name, "Production")
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
s.Require().Len(found, 2)
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
s.Require().Len(found, 1)
}
// --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"})
k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
s.Require().Equal(int64(2), affected)
got1, _ := s.repo.GetByID(s.ctx, k1.ID)
got2, _ := s.repo.GetByID(s.ctx, k2.ID)
s.Require().Nil(got1.GroupID)
s.Require().Nil(got2.GroupID)
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().Zero(count)
}
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-test-1",
Name: "My Key",
GroupID: &group.ID,
Status: service.StatusActive,
}))
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
s.Require().Equal(key.ID, got.ID)
s.Require().NotNil(got.User)
s.Require().Equal(user.ID, got.User.ID)
s.Require().NotNil(got.Group)
s.Require().Equal(group.ID, got.Group.ID)
key.Name = "Renamed"
key.Status = service.StatusDisabled
key.GroupID = nil
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
got2, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got2.Name)
s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(keys, 1)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1")
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists, "expected key to exist")
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchApiKeys")
s.Require().Len(found, 1)
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-test-2",
Name: "Group Key",
GroupID: &group.ID,
})
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear")
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
got3, err := s.repo.GetByID(s.ctx, k2.ID)
s.Require().NoError(err, "GetByID")
s.Require().Nil(got3.GroupID, "expected GroupID cleared")
countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID after clear")
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}

View File

@@ -0,0 +1,20 @@
package repository
import "gorm.io/gorm"
// AutoMigrate runs schema migrations for all repository persistence models.
// Persistence models are defined within individual `*_repo.go` files.
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&userModel{},
&apiKeyModel{},
&groupModel{},
&accountModel{},
&accountGroupModel{},
&proxyModel{},
&redeemCodeModel{},
&usageLogModel{},
&settingModel{},
&userSubscriptionModel{},
)
}

View File

@@ -8,8 +8,7 @@ import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
@@ -19,6 +18,16 @@ const (
billingCacheTTL = 5 * time.Minute
)
// billingBalanceKey generates the Redis key for user balance cache.
func billingBalanceKey(userID int64) string {
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
}
// billingSubKey generates the Redis key for subscription cache.
func billingSubKey(userID, groupID int64) string {
return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
}
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
@@ -58,12 +67,12 @@ type billingCache struct {
rdb *redis.Client
}
func NewBillingCache(rdb *redis.Client) ports.BillingCache {
func NewBillingCache(rdb *redis.Client) service.BillingCache {
return &billingCache{rdb: rdb}
}
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
key := billingBalanceKey(userID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return 0, err
@@ -72,12 +81,12 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
}
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
key := billingBalanceKey(userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
}
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
key := billingBalanceKey(userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
@@ -86,12 +95,12 @@ func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amou
}
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
key := billingBalanceKey(userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
key := billingSubKey(userID, groupID)
result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
@@ -102,8 +111,8 @@ func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID
return c.parseSubscriptionCache(result)
}
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) {
result := &ports.SubscriptionCacheData{}
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) {
result := &service.SubscriptionCacheData{}
result.Status = data[subFieldStatus]
if result.Status == "" {
@@ -136,12 +145,12 @@ func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.Su
return result, nil
}
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error {
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error {
if data == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
key := billingSubKey(userID, groupID)
fields := map[string]any{
subFieldStatus: data.Status,
@@ -160,7 +169,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
}
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
key := billingSubKey(userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
@@ -169,6 +178,6 @@ func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, grou
}
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
key := billingSubKey(userID, groupID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,283 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type BillingCacheSuite struct {
IntegrationRedisSuite
}
func (s *BillingCacheSuite) TestUserBalance() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
_, err := cache.GetUserBalance(ctx, 1)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
},
},
{
name: "deduct_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(1)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
_, err := rdb.Get(ctx, balanceKey).Result()
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
},
},
{
name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(2)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
got, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance")
require.Equal(s.T(), 10.5, got, "balance mismatch")
ttl, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
},
},
{
name: "deduct_reduces_balance",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(3)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
got, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance after deduct")
require.Equal(s.T(), 8.25, got, "deduct mismatch")
},
},
{
name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(100)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
exists, err := rdb.Exists(ctx, balanceKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
exists, err = rdb.Exists(ctx, balanceKey).Result()
require.NoError(s.T(), err, "Exists after invalidate")
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
_, err = cache.GetUserBalance(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
},
},
{
name: "deduct_refreshes_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(103)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL before deduct")
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
balance, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance")
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL after deduct")
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func (s *BillingCacheSuite) TestSubscriptionCache() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(10)
groupID := int64(20)
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
},
},
{
name: "update_usage_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(11)
groupID := int64(21)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
exists, err := rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
},
},
{
name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(12)
groupID := int64(22)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 7,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.NoError(s.T(), err, "GetSubscriptionCache")
require.Equal(s.T(), "active", gotSub.Status)
require.Equal(s.T(), int64(7), gotSub.Version)
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
ttl, err := rdb.TTL(ctx, subKey).Result()
require.NoError(s.T(), err, "TTL subKey")
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
},
},
{
name: "update_usage_increments_all_fields",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(13)
groupID := int64(23)
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.NoError(s.T(), err, "GetSubscriptionCache after update")
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
},
},
{
name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(101)
groupID := int64(10)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
exists, err := rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
exists, err = rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists after invalidate")
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
},
},
{
name: "missing_status_returns_parsing_error",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(102)
groupID := int64(11)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]any{
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
"daily_usage": 1.0,
"weekly_usage": 2.0,
"monthly_usage": 3.0,
"version": 1,
}
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.Error(s.T(), err, "expected error for missing status field")
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
require.Equal(s.T(), "invalid cache: missing status", err.Error())
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func TestBillingCacheSuite(t *testing.T) {
suite.Run(t, new(BillingCacheSuite))
}

View File

@@ -0,0 +1,87 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestBillingBalanceKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "billing:balance:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "billing:balance:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "billing:balance:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "billing:balance:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingBalanceKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}
func TestBillingSubKey(t *testing.T) {
tests := []struct {
name string
userID int64
groupID int64
expected string
}{
{
name: "normal_ids",
userID: 123,
groupID: 456,
expected: "billing:sub:123:456",
},
{
name: "zero_ids",
userID: 0,
groupID: 0,
expected: "billing:sub:0:0",
},
{
name: "negative_ids",
userID: -1,
groupID: -2,
expected: "billing:sub:-1:-2",
},
{
name: "max_int64_ids",
userID: math.MaxInt64,
groupID: math.MaxInt64,
expected: "billing:sub:9223372036854775807:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingSubKey(tc.userID, tc.groupID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -16,20 +16,28 @@ import (
"github.com/imroc/req/v3"
)
type claudeOAuthService struct{}
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
return &claudeOAuthService{}
return &claudeOAuthService{
baseURL: "https://claude.ai",
tokenURL: oauth.TokenURL,
clientFactory: createReqClient,
}
}
type claudeOAuthService struct {
baseURL string
tokenURL string
clientFactory func(proxyURL string) *req.Client
}
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := createReqClient(proxyURL)
client := s.clientFactory(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
}
targetURL := "https://claude.ai/api/organizations"
targetURL := s.baseURL + "/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
@@ -61,9 +69,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
}
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := createReqClient(proxyURL)
client := s.clientFactory(proxyURL)
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
reqBody := map[string]any{
"response_type": "code",
@@ -133,12 +141,12 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
return fullCode, nil
}
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
client := s.clientFactory(proxyURL)
// Parse code which may contain state in format "authCode#state"
authCode := code
@@ -161,7 +169,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
@@ -171,7 +179,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
Post(s.tokenURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
@@ -189,20 +197,24 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
}
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
client := s.clientFactory(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", oauth.ClientID)
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
// Anthropic OAuth API 期望 JSON 格式的请求体
reqBody := map[string]any{
"grant_type": "refresh_token",
"refresh_token": refreshToken,
"client_id": oauth.ClientID,
}
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
@@ -226,3 +238,13 @@ func createReqClient(proxyURL string) *req.Client {
return client
}
func prefix(s string, n int) string {
if n <= 0 {
return ""
}
if len(s) <= n {
return s
}
return s[:n]
}

View File

@@ -0,0 +1,372 @@
package repository
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ClaudeOAuthServiceSuite struct {
suite.Suite
srv *httptest.Server
client *claudeOAuthService
}
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// requestCapture holds captured request data for assertions in the main goroutine.
type requestCapture struct {
path string
method string
cookies []*http.Cookie
body []byte
bodyJSON map[string]any
contentType string
}
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
errContain string
wantUUID string
validate func(captured requestCapture)
}{
{
name: "success",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`[{"uuid":"org-1"}]`))
},
wantUUID: "org-1",
validate: func(captured requestCapture) {
require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path")
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
require.Equal(s.T(), "sessionKey", captured.cookies[0].Name)
require.Equal(s.T(), "sess", captured.cookies[0].Value)
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("unauthorized"))
},
wantErr: true,
errContain: "401",
},
{
name: "invalid_json_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte("not-json"))
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.cookies = r.Cookies()
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = s.srv.URL
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
if tt.wantErr {
require.Error(s.T(), err)
if tt.errContain != "" {
require.ErrorContains(s.T(), err, tt.errContain)
}
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantUUID, got)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
wantCode string
validate func(captured requestCapture)
}{
{
name: "parses_redirect_uri",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE",
})
},
wantCode: "AUTH#STATE",
validate: func(captured requestCapture) {
require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path)
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
require.Equal(s.T(), "sess", captured.cookies[0].Value)
require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
require.Equal(s.T(), "st", captured.bodyJSON["state"])
},
},
{
name: "missing_code_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"redirect_uri": oauth.RedirectURI + "?state=STATE", // no code
})
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.method = r.Method
captured.cookies = r.Cookies()
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = s.srv.URL
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantCode, code)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
tests := []struct {
name string
handler http.HandlerFunc
code string
wantErr bool
wantResp *oauth.TokenResponse
validate func(captured requestCapture)
}{
{
name: "sends_state_when_embedded",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "at",
TokenType: "bearer",
ExpiresIn: 3600,
RefreshToken: "rt",
Scope: "s",
})
},
code: "AUTH#STATE2",
wantResp: &oauth.TokenResponse{
AccessToken: "at",
RefreshToken: "rt",
},
validate: func(captured requestCapture) {
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type")
require.Equal(s.T(), "AUTH", captured.bodyJSON["code"])
require.Equal(s.T(), "STATE2", captured.bodyJSON["state"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("bad request"))
},
code: "AUTH",
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = s.srv.URL
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
wantResp *oauth.TokenResponse
validate func(captured requestCapture)
}{
{
name: "sends_json_format",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "new_access_token",
TokenType: "bearer",
ExpiresIn: 28800,
RefreshToken: "new_refresh_token",
Scope: "user:profile user:inference",
})
},
wantResp: &oauth.TokenResponse{
AccessToken: "new_access_token",
RefreshToken: "new_refresh_token",
},
validate: func(captured requestCapture) {
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
// 验证使用 JSON 格式(不是 form 格式)
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"),
"expected JSON content-type, got: %s", captured.contentType)
// 验证 JSON body 内容
require.Equal(s.T(), "refresh_token", captured.bodyJSON["grant_type"])
require.Equal(s.T(), "rt", captured.bodyJSON["refresh_token"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
},
},
{
name: "returns_new_refresh_token",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "at",
TokenType: "bearer",
ExpiresIn: 28800,
RefreshToken: "rotated_rt", // Anthropic rotates refresh tokens
})
},
wantResp: &oauth.TokenResponse{
AccessToken: "at",
RefreshToken: "rotated_rt",
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = s.srv.URL
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func TestClaudeOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeOAuthServiceSuite))
}

View File

@@ -12,10 +12,14 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
)
type claudeUsageService struct{}
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
type claudeUsageService struct {
usageURL string
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
return &claudeUsageService{}
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
}
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
@@ -35,7 +39,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
Timeout: 30 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}

View File

@@ -0,0 +1,105 @@
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ClaudeUsageServiceSuite struct {
suite.Suite
srv *httptest.Server
fetcher *claudeUsageService
}
func (s *ClaudeUsageServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// usageRequestCapture holds captured request data for assertions in the main goroutine.
type usageRequestCapture struct {
authorization string
anthropicBeta string
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
var captured usageRequestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.authorization = r.Header.Get("Authorization")
captured.anthropicBeta = r.Header.Get("anthropic-beta")
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{
"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
}`)
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage")
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch")
// Assertions on captured request data
require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch")
require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "nope")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 401")
require.ErrorContains(s.T(), err, "nope")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "decode response failed")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never respond - simulate slow server
<-r.Context().Done()
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := s.fetcher.FetchUsage(ctx, "at", "")
require.Error(s.T(), err, "expected error for cancelled context")
}
func TestClaudeUsageServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeUsageServiceSuite))
}

View File

@@ -5,8 +5,7 @@ import (
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
@@ -107,7 +106,7 @@ type concurrencyCache struct {
rdb *redis.Client
}
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache {
return &concurrencyCache{rdb: rdb}
}

View File

@@ -0,0 +1,231 @@
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ConcurrencyCacheSuite struct {
IntegrationRedisSuite
cache service.ConcurrencyCache
}
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
accountID := int64(10)
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
require.NoError(s.T(), err, "AcquireAccountSlot 1")
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
require.NoError(s.T(), err, "AcquireAccountSlot 2")
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
require.NoError(s.T(), err, "AcquireAccountSlot 3")
require.False(s.T(), ok, "expected third acquire to fail")
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err, "GetAccountConcurrency")
require.Equal(s.T(), 2, cur, "concurrency mismatch")
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err, "GetAccountConcurrency after release")
require.Equal(s.T(), 1, cur, "expected 1 after release")
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
accountID := int64(11)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID)
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
require.NoError(s.T(), err, "AcquireAccountSlot")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
accountID := int64(12)
reqID := "dup-req"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Acquiring with same reqID should be idempotent
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
accountID := int64(13)
reqID := "release-test"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
// Releasing again should not error
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
// Releasing non-existent should not error
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
accountID := int64(14)
reqID := "max-zero-test"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
require.NoError(s.T(), err)
require.False(s.T(), ok, "expected acquire to fail with max=0")
}
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
userID := int64(42)
reqID1, reqID2 := "req1", "req2"
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
require.NoError(s.T(), err, "AcquireUserSlot")
require.True(s.T(), ok)
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
require.NoError(s.T(), err, "AcquireUserSlot 2")
require.False(s.T(), ok, "expected second acquire to fail at max=1")
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
require.NoError(s.T(), err, "GetUserConcurrency")
require.Equal(s.T(), 1, cur, "expected concurrency=1")
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
// Releasing a non-existent slot should not error
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
require.NoError(s.T(), err, "GetUserConcurrency after release")
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
}
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
userID := int64(200)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID)
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
require.NoError(s.T(), err, "AcquireUserSlot")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected wait count 1")
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
// Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
// Verify no key was created or it's not negative
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
// Set count to 1, then decrement twice
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
require.NoError(s.T(), err, "IncrementWaitCount")
require.True(s.T(), ok)
// Decrement once (1 -> 0)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
// Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
// Verify count is 0, not negative
val, err = s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey after double decrement")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
// When no slots exist, GetUserConcurrency should return 0
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}

View File

@@ -5,36 +5,40 @@ import (
"encoding/json"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const verifyCodeKeyPrefix = "verify_code:"
// verifyCodeKey generates the Redis key for email verification code.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
}
type emailCache struct {
rdb *redis.Client
}
func NewEmailCache(rdb *redis.Client) ports.EmailCache {
func NewEmailCache(rdb *redis.Client) service.EmailCache {
return &emailCache{rdb: rdb}
}
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) {
key := verifyCodeKeyPrefix + email
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
key := verifyCodeKey(email)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data ports.VerificationCodeData
var data service.VerificationCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error {
key := verifyCodeKeyPrefix + email
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
key := verifyCodeKey(email)
val, err := json.Marshal(data)
if err != nil {
return err
@@ -43,6 +47,6 @@ func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data
}
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
key := verifyCodeKeyPrefix + email
key := verifyCodeKey(email)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,92 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type EmailCacheSuite struct {
IntegrationRedisSuite
cache service.EmailCache
}
func (s *EmailCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewEmailCache(s.rdb)
}
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
}
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
email := "a@example.com"
emailTTL := 2 * time.Minute
data := &service.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
got, err := s.cache.GetVerificationCode(s.ctx, email)
require.NoError(s.T(), err, "GetVerificationCode")
require.Equal(s.T(), "123456", got.Code)
require.Equal(s.T(), 1, got.Attempts)
}
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
email := "ttl@example.com"
emailTTL := 2 * time.Minute
data := &service.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
emailKey := verifyCodeKeyPrefix + email
ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
require.NoError(s.T(), err, "TTL emailKey")
s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
}
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
email := "delete@example.com"
data := &service.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
// Verify it exists
_, err := s.cache.GetVerificationCode(s.ctx, email)
require.NoError(s.T(), err, "GetVerificationCode before delete")
// Delete
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")
// Verify it's gone
_, err = s.cache.GetVerificationCode(s.ctx, email)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
// Deleting a non-existent key should not error
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
}
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
emailKey := verifyCodeKeyPrefix + "corrupted@example.com"
require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")
_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
require.Error(s.T(), err, "expected error for corrupted JSON")
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
}
func TestEmailCacheSuite(t *testing.T) {
suite.Run(t, new(EmailCacheSuite))
}

View File

@@ -0,0 +1,45 @@
//go:build unit
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestVerifyCodeKey(t *testing.T) {
tests := []struct {
name string
email string
expected string
}{
{
name: "normal_email",
email: "user@example.com",
expected: "verify_code:user@example.com",
},
{
name: "empty_email",
email: "",
expected: "verify_code:",
},
{
name: "email_with_plus",
email: "user+tag@example.com",
expected: "verify_code:user+tag@example.com",
},
{
name: "email_with_special_chars",
email: "user.name+tag@sub.domain.com",
expected: "verify_code:user.name+tag@sub.domain.com",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := verifyCodeKey(tc.email)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -0,0 +1,40 @@
package repository
import (
"errors"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"gorm.io/gorm"
)
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
if err == nil {
return nil
}
if notFound != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return notFound.WithCause(err)
}
if conflict != nil && isUniqueConstraintViolation(err) {
return conflict.WithCause(err)
}
return err
}
func isUniqueConstraintViolation(err error) bool {
if err == nil {
return false
}
if errors.Is(err, gorm.ErrDuplicatedKey) {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "duplicate key") ||
strings.Contains(msg, "unique constraint") ||
strings.Contains(msg, "duplicate entry")
}

View File

@@ -0,0 +1,177 @@
//go:build integration
package repository
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"gorm.io/datatypes"
"gorm.io/gorm"
)
func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel {
t.Helper()
if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash"
}
if u.Role == "" {
u.Role = service.RoleUser
}
if u.Status == "" {
u.Status = service.StatusActive
}
if u.Concurrency == 0 {
u.Concurrency = 5
}
if u.CreatedAt.IsZero() {
u.CreatedAt = time.Now()
}
if u.UpdatedAt.IsZero() {
u.UpdatedAt = u.CreatedAt
}
require.NoError(t, db.Create(u).Error, "create user")
return u
}
func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel {
t.Helper()
if g.Platform == "" {
g.Platform = service.PlatformAnthropic
}
if g.Status == "" {
g.Status = service.StatusActive
}
if g.SubscriptionType == "" {
g.SubscriptionType = service.SubscriptionTypeStandard
}
if g.CreatedAt.IsZero() {
g.CreatedAt = time.Now()
}
if g.UpdatedAt.IsZero() {
g.UpdatedAt = g.CreatedAt
}
require.NoError(t, db.Create(g).Error, "create group")
return g
}
func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel {
t.Helper()
if p.Protocol == "" {
p.Protocol = "http"
}
if p.Host == "" {
p.Host = "127.0.0.1"
}
if p.Port == 0 {
p.Port = 8080
}
if p.Status == "" {
p.Status = service.StatusActive
}
if p.CreatedAt.IsZero() {
p.CreatedAt = time.Now()
}
if p.UpdatedAt.IsZero() {
p.UpdatedAt = p.CreatedAt
}
require.NoError(t, db.Create(p).Error, "create proxy")
return p
}
func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel {
t.Helper()
if a.Platform == "" {
a.Platform = service.PlatformAnthropic
}
if a.Type == "" {
a.Type = service.AccountTypeOAuth
}
if a.Status == "" {
a.Status = service.StatusActive
}
if !a.Schedulable {
a.Schedulable = true
}
if a.Credentials == nil {
a.Credentials = datatypes.JSONMap{}
}
if a.Extra == nil {
a.Extra = datatypes.JSONMap{}
}
if a.CreatedAt.IsZero() {
a.CreatedAt = time.Now()
}
if a.UpdatedAt.IsZero() {
a.UpdatedAt = a.CreatedAt
}
require.NoError(t, db.Create(a).Error, "create account")
return a
}
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *apiKeyModel) *apiKeyModel {
t.Helper()
if k.Status == "" {
k.Status = service.StatusActive
}
if k.CreatedAt.IsZero() {
k.CreatedAt = time.Now()
}
if k.UpdatedAt.IsZero() {
k.UpdatedAt = k.CreatedAt
}
require.NoError(t, db.Create(k).Error, "create api key")
return k
}
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *redeemCodeModel) *redeemCodeModel {
t.Helper()
if c.Status == "" {
c.Status = service.StatusUnused
}
if c.Type == "" {
c.Type = service.RedeemTypeBalance
}
if c.CreatedAt.IsZero() {
c.CreatedAt = time.Now()
}
require.NoError(t, db.Create(c).Error, "create redeem code")
return c
}
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel) *userSubscriptionModel {
t.Helper()
if s.Status == "" {
s.Status = service.SubscriptionStatusActive
}
now := time.Now()
if s.StartsAt.IsZero() {
s.StartsAt = now.Add(-1 * time.Hour)
}
if s.ExpiresAt.IsZero() {
s.ExpiresAt = now.Add(24 * time.Hour)
}
if s.AssignedAt.IsZero() {
s.AssignedAt = now
}
if s.CreatedAt.IsZero() {
s.CreatedAt = now
}
if s.UpdatedAt.IsZero() {
s.UpdatedAt = now
}
require.NoError(t, db.Create(s).Error, "create user subscription")
return s
}
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
t.Helper()
require.NoError(t, db.Create(&accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
CreatedAt: time.Now(),
}).Error, "create account_group")
}

View File

@@ -4,8 +4,7 @@ import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
@@ -15,7 +14,7 @@ type gatewayCache struct {
rdb *redis.Client
}
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache {
func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return &gatewayCache{rdb: rdb}
}

View File

@@ -0,0 +1,92 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GatewayCacheSuite struct {
IntegrationRedisSuite
cache service.GatewayCache
}
func (s *GatewayCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGatewayCache(s.rdb)
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
_, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
}
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
sessionID := "s1"
accountID := int64(99)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
require.NoError(s.T(), err, "GetSessionAccountID")
require.Equal(s.T(), accountID, sid, "session id mismatch")
}
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
sessionID := "s2"
accountID := int64(100)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sessionKey := stickySessionPrefix + sessionID
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL sessionKey after Set")
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
}
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
sessionID := "s3"
accountID := int64(101)
initialTTL := 1 * time.Minute
refreshTTL := 3 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL")
sessionKey := stickySessionPrefix + sessionID
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL after Refresh")
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
}
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
// RefreshSessionTTL on a missing key should not error (no-op)
err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute)
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
sessionKey := stickySessionPrefix + sessionID
// Set a non-integer value
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
_, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
require.Error(s.T(), err, "expected error for corrupted value")
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite))
}

View File

@@ -0,0 +1,117 @@
package repository
import (
"context"
"fmt"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
)
type geminiOAuthClient struct {
tokenURL string
cfg *config.Config
}
func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient {
return &geminiOAuthClient{
tokenURL: geminicli.TokenURL,
cfg: cfg,
}
}
func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) {
client := createGeminiReqClient(proxyURL)
// Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = ""
}
oauthCfg, err := geminicli.EffectiveOAuthConfig(oauthCfgInput, oauthType)
if err != nil {
return nil, err
}
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("client_id", oauthCfg.ClientID)
formData.Set("client_secret", oauthCfg.ClientSecret)
formData.Set("code", code)
formData.Set("code_verifier", codeVerifier)
formData.Set("redirect_uri", redirectURI)
var tokenResp geminicli.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(c.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String()))
}
return &tokenResp, nil
}
func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) {
client := createGeminiReqClient(proxyURL)
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = ""
}
oauthCfg, err := geminicli.EffectiveOAuthConfig(oauthCfgInput, oauthType)
if err != nil {
return nil, err
}
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", oauthCfg.ClientID)
formData.Set("client_secret", oauthCfg.ClientSecret)
var tokenResp geminicli.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(c.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String()))
}
return &tokenResp, nil
}
func createGeminiReqClient(proxyURL string) *req.Client {
client := req.C().SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}

View File

@@ -0,0 +1,44 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
geminiTokenKeyPrefix = "gemini:token:"
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:"
)
type geminiTokenCache struct {
rdb *redis.Client
}
func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
return &geminiTokenCache{rdb: rdb}
}
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
return c.rdb.Get(ctx, key).Result()
}
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
return c.rdb.Set(ctx, key, token, ttl).Err()
}
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,105 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
)
type geminiCliCodeAssistClient struct {
baseURL string
}
func NewGeminiCliCodeAssistClient() service.GeminiCliCodeAssistClient {
return &geminiCliCodeAssistClient{baseURL: geminicli.GeminiCliBaseURL}
}
func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessToken, proxyURL string, reqBody *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) {
if reqBody == nil {
reqBody = defaultLoadCodeAssistRequest()
}
var out geminicli.LoadCodeAssistResponse
resp, err := createGeminiCliReqClient(proxyURL).R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Content-Type", "application/json").
SetHeader("User-Agent", geminicli.GeminiCLIUserAgent).
SetBody(reqBody).
SetSuccessResult(&out).
Post(c.baseURL + "/v1internal:loadCodeAssist")
if err != nil {
fmt.Printf("[CodeAssist] LoadCodeAssist request error: %v\n", err)
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String())
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body)
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body)
}
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil
}
func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken, proxyURL string, reqBody *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) {
if reqBody == nil {
reqBody = defaultOnboardUserRequest()
}
fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody)
var out geminicli.OnboardUserResponse
resp, err := createGeminiCliReqClient(proxyURL).R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Content-Type", "application/json").
SetHeader("User-Agent", geminicli.GeminiCLIUserAgent).
SetBody(reqBody).
SetSuccessResult(&out).
Post(c.baseURL + "/v1internal:onboardUser")
if err != nil {
fmt.Printf("[CodeAssist] OnboardUser request error: %v\n", err)
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String())
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body)
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body)
}
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil
}
func createGeminiCliReqClient(proxyURL string) *req.Client {
client := req.C().SetTimeout(30 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}
func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest {
return &geminicli.LoadCodeAssistRequest{
Metadata: geminicli.LoadCodeAssistMetadata{
IDEType: "ANTIGRAVITY",
Platform: "PLATFORM_UNSPECIFIED",
PluginType: "GEMINI",
},
}
}
func defaultOnboardUserRequest() *geminicli.OnboardUserRequest {
return &geminicli.OnboardUserRequest{
TierID: "LEGACY",
Metadata: geminicli.LoadCodeAssistMetadata{
IDEType: "ANTIGRAVITY",
Platform: "PLATFORM_UNSPECIFIED",
PluginType: "GEMINI",
},
}
}

View File

@@ -0,0 +1,328 @@
package repository
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GitHubReleaseServiceSuite struct {
suite.Suite
srv *httptest.Server
client *githubReleaseClient
tempDir string
}
// testTransport redirects requests to the test server
type testTransport struct {
testServerURL string
}
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the URL to point to our test server
testURL := t.testServerURL + req.URL.Path
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, testURL, req.Body)
if err != nil {
return nil, err
}
newReq.Header = req.Header
return http.DefaultTransport.RoundTrip(newReq)
}
func (s *GitHubReleaseServiceSuite) SetupTest() {
s.tempDir = s.T().TempDir()
}
func (s *GitHubReleaseServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "100")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file1.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
require.Error(s.T(), err, "expected error for oversized download with Content-Length")
_, statErr := os.Stat(dest)
require.Error(s.T(), statErr, "expected file to not exist for rejected download")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
for i := 0; i < 10; i++ {
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file2.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
require.Error(s.T(), err, "expected error for oversized chunked download")
_, statErr := os.Stat(dest)
require.Error(s.T(), statErr, "expected file to be cleaned up for oversized chunked download")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
for i := 0; i < 10; i++ {
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file3.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
require.NoError(s.T(), err, "expected success")
b, err := os.ReadFile(dest)
require.NoError(s.T(), err, "read")
require.True(s.T(), strings.HasPrefix(string(b), "b"), "downloaded content should start with 'b'")
require.Len(s.T(), b, 100, "downloaded content length mismatch")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "notfound.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
require.Error(s.T(), err, "expected error for 404")
_, statErr := os.Stat(dest)
require.Error(s.T(), statErr, "expected file to not exist for 404")
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("sum"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.NoError(s.T(), err, "FetchChecksumFile")
require.Equal(s.T(), "sum", string(body), "checksum body mismatch")
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.Error(s.T(), err, "expected error for non-200")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
ctx, cancel := context.WithCancel(context.Background())
cancel()
dest := filepath.Join(s.tempDir, "cancelled.bin")
err := s.client.DownloadFile(ctx, s.srv.URL, dest, 100)
require.Error(s.T(), err, "expected error for cancelled context")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "invalid.bin")
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("content"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
// Use a path that cannot be created (directory doesn't exist)
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
require.Error(s.T(), err, "expected error for invalid destination path")
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
releaseJSON := `{
"tag_name": "v1.0.0",
"name": "Release 1.0.0",
"body": "Release notes",
"html_url": "https://github.com/test/repo/releases/v1.0.0",
"assets": [
{
"name": "app-linux-amd64.tar.gz",
"browser_download_url": "https://github.com/test/repo/releases/download/v1.0.0/app-linux-amd64.tar.gz"
}
]
}`
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(releaseJSON))
}))
// Use custom transport to redirect requests to test server
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
}
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
require.NoError(s.T(), err)
require.Equal(s.T(), "v1.0.0", release.TagName)
require.Equal(s.T(), "Release 1.0.0", release.Name)
require.Len(s.T(), release.Assets, 1)
require.Equal(s.T(), "app-linux-amd64.tar.gz", release.Assets[0].Name)
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
require.Error(s.T(), err)
require.Contains(s.T(), err.Error(), "404")
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not valid json"))
}))
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
require.Error(s.T(), err)
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := s.client.FetchLatestRelease(ctx, "test/repo")
require.Error(s.T(), err)
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := s.client.FetchChecksumFile(ctx, s.srv.URL)
require.Error(s.T(), err)
}
func TestGitHubReleaseServiceSuite(t *testing.T) {
suite.Run(t, new(GitHubReleaseServiceSuite))
}

View File

@@ -2,51 +2,68 @@ package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/model"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type GroupRepository struct {
type groupRepository struct {
db *gorm.DB
}
func NewGroupRepository(db *gorm.DB) *GroupRepository {
return &GroupRepository{db: db}
func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return &groupRepository{db: db}
}
func (r *GroupRepository) Create(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Create(group).Error
}
func (r *GroupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
var group model.Group
err := r.db.WithContext(ctx).First(&group, id).Error
if err != nil {
return nil, err
func (r *groupRepository) Create(ctx context.Context, group *service.Group) error {
m := groupModelFromService(group)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyGroupModelToService(group, m)
}
return &group, nil
return translatePersistenceError(err, nil, service.ErrGroupExists)
}
func (r *GroupRepository) Update(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Save(group).Error
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
var m groupModel
err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
group := groupModelToService(&m)
count, _ := r.GetAccountCount(ctx, group.ID)
group.AccountCount = count
return group, nil
}
func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
func (r *groupRepository) Update(ctx context.Context, group *service.Group) error {
m := groupModelFromService(group)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyGroupModelToService(group, m)
}
return err
}
func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&groupModel{}, id).Error
}
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil)
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
var groups []model.Group
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
var groups []groupModel
var total int64
db := r.db.WithContext(ctx).Model(&model.Group{})
db := r.db.WithContext(ctx).Model(&groupModel{})
// Apply filters
if platform != "" {
@@ -67,72 +84,198 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err
}
// 获取每个分组的账号数量
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
}
return groups, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outGroups, paginationResultFromTotal(total, params), nil
}
func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Order("id ASC").Find(&groups).Error
if err != nil {
return nil, err
}
// 获取每个分组的账号数量
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
return groups, nil
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
}
return outGroups, nil
}
func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", service.StatusActive, platform).Order("id ASC").Find(&groups).Error
if err != nil {
return nil, err
}
// 获取每个分组的账号数量
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
return groups, nil
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
}
return outGroups, nil
}
func (r *GroupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
err := r.db.WithContext(ctx).Model(&groupModel{}).Where("name = ?", name).Count(&count).Error
return count > 0, err
}
func (r *GroupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
err := r.db.WithContext(ctx).Table("account_groups").Where("group_id = ?", groupID).Count(&count).Error
return count, err
}
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *GroupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Exec("DELETE FROM account_groups WHERE group_id = ?", groupID)
return result.RowsAffected, result.Error
}
// DB 返回底层数据库连接,用于事务处理
func (r *GroupRepository) DB() *gorm.DB {
return r.db
func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
group, err := r.GetByID(ctx, id)
if err != nil {
return nil, err
}
var affectedUserIDs []int64
if group.IsSubscriptionType() {
if err := r.db.WithContext(ctx).
Table("user_subscriptions").
Where("group_id = ?", id).
Pluck("user_id", &affectedUserIDs).Error; err != nil {
return nil, err
}
}
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 删除订阅类型分组的订阅记录
if group.IsSubscriptionType() {
if err := tx.Exec("DELETE FROM user_subscriptions WHERE group_id = ?", id).Error; err != nil {
return err
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if err := tx.Exec("UPDATE api_keys SET group_id = NULL WHERE group_id = ?", id).Error; err != nil {
return err
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Exec(
"UPDATE users SET allowed_groups = array_remove(allowed_groups, ?) WHERE ? = ANY(allowed_groups)",
id, id,
).Error; err != nil {
return err
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Exec("DELETE FROM account_groups WHERE group_id = ?", id).Error; err != nil {
return err
}
// 5. 删除分组本身(带锁,避免并发写)
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&groupModel{}, id).Error; err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return affectedUserIDs, nil
}
type groupModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"uniqueIndex;size:100;not null"`
Description string `gorm:"type:text"`
Platform string `gorm:"size:50;default:anthropic;not null"`
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null"`
IsExclusive bool `gorm:"default:false;not null"`
Status string `gorm:"size:20;default:active;not null"`
SubscriptionType string `gorm:"size:20;default:standard;not null"`
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (groupModel) TableName() string { return "groups" }
func groupModelToService(m *groupModel) *service.Group {
if m == nil {
return nil
}
return &service.Group{
ID: m.ID,
Name: m.Name,
Description: m.Description,
Platform: m.Platform,
RateMultiplier: m.RateMultiplier,
IsExclusive: m.IsExclusive,
Status: m.Status,
SubscriptionType: m.SubscriptionType,
DailyLimitUSD: m.DailyLimitUSD,
WeeklyLimitUSD: m.WeeklyLimitUSD,
MonthlyLimitUSD: m.MonthlyLimitUSD,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func groupModelFromService(sg *service.Group) *groupModel {
if sg == nil {
return nil
}
return &groupModel{
ID: sg.ID,
Name: sg.Name,
Description: sg.Description,
Platform: sg.Platform,
RateMultiplier: sg.RateMultiplier,
IsExclusive: sg.IsExclusive,
Status: sg.Status,
SubscriptionType: sg.SubscriptionType,
DailyLimitUSD: sg.DailyLimitUSD,
WeeklyLimitUSD: sg.WeeklyLimitUSD,
MonthlyLimitUSD: sg.MonthlyLimitUSD,
CreatedAt: sg.CreatedAt,
UpdatedAt: sg.UpdatedAt,
}
}
func applyGroupModelToService(group *service.Group, m *groupModel) {
if group == nil || m == nil {
return
}
group.ID = m.ID
group.CreatedAt = m.CreatedAt
group.UpdatedAt = m.UpdatedAt
}

View File

@@ -0,0 +1,236 @@
//go:build integration
package repository
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type GroupRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *groupRepository
}
func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewGroupRepository(s.db).(*groupRepository)
}
func TestGroupRepoSuite(t *testing.T) {
suite.Run(t, new(GroupRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *GroupRepoSuite) TestCreate() {
group := &service.Group{
Name: "test-create",
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, group)
s.Require().NoError(err, "Create")
s.Require().NotZero(group.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *GroupRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *GroupRepoSuite) TestUpdate() {
group := groupModelToService(mustCreateGroup(s.T(), s.db, &groupModel{Name: "original"}))
group.Name = "updated"
err := s.repo.Update(s.ctx, group)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *GroupRepoSuite) TestDelete() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(groups, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(service.PlatformOpenAI, groups[0].Platform)
}
func (s *GroupRepoSuite) TestListWithFilters_Status() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Status: service.StatusDisabled})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(service.StatusDisabled, groups[0].Status)
}
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", IsExclusive: false})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", IsExclusive: true})
isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().True(groups[0].IsExclusive)
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g1",
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g2",
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
IsExclusive: true,
})
a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1)
s.Require().Equal(g2.ID, groups[0].ID, "ListWithFilters returned wrong group")
s.Require().Equal(int64(1), groups[0].AccountCount, "AccountCount mismatch")
}
// --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "active1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "inactive1", Status: service.StatusDisabled})
groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(groups, 1)
s.Require().Equal("active1", groups[0].Name)
}
func (s *GroupRepoSuite) TestListActiveByPlatform() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g3", Platform: service.PlatformAnthropic, Status: service.StatusDisabled})
groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListActiveByPlatform")
s.Require().Len(groups, 1)
s.Require().Equal("g1", groups[0].Name)
}
// --- ExistsByName ---
func (s *GroupRepoSuite) TestExistsByName() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "existing-group"})
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
s.Require().NoError(err, "ExistsByName")
s.Require().True(exists)
notExists, err := s.repo.ExistsByName(s.ctx, "non-existing")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- GetAccountCount ---
func (s *GroupRepoSuite) TestGetAccountCount() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(2), count)
}
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
// --- DeleteAccountGroupsByGroupID ---
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(0), count, "expected 0 account groups")
}
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-multi"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
a3 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
s.Require().NoError(err)
s.Require().Equal(int64(3), affected)
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}

View File

@@ -6,7 +6,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// httpUpstreamService is a generic HTTP upstream service that can be used for
@@ -17,7 +17,7 @@ type httpUpstreamService struct {
}
// NewHTTPUpstream creates a new generic HTTP upstream service
func NewHTTPUpstream(cfg *config.Config) ports.HTTPUpstream {
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second

Some files were not shown because too many files have changed in this diff Show More