Compare commits

...

145 Commits

Author SHA1 Message Date
shaw
2b79c4e8b7 chore: home页面更新gemini为已支持 2025-12-26 23:13:00 +08:00
shaw
429f38d0c9 Merge PR #37: Add Gemini OAuth and Messages Compat Support 2025-12-26 22:42:34 +08:00
IanShaw027
2714be99a9 feat(test): 添加 Gemini 双响应格式支持
添加对两种 Gemini 响应格式的支持:
- AI Studio: `{"candidates": [...]}`
- Gemini CLI: `{"response": {"candidates": [...]}}`

通过 unwrap 逻辑自动检测并适配两种格式,确保账号测试功能
对所有 Gemini 账号类型都能正常工作。

合并 PR #43 的剩余功能到 PR #37
2025-12-26 22:31:12 +08:00
IanShaw027
d851818035 fix(lint): 修复 gofmt 格式问题
修复 golangci-lint 检查失败的问题:
- gemini_token_provider.go: 删除 import 后多余空行
- gemini_token_refresher.go: 删除 import 后多余空行

Fixes CI golangci-lint check for PR #37
2025-12-26 22:24:22 +08:00
IanShaw027
576bf4639c refactor: 统一使用 mergeMap 函数提升代码一致性
根据 Gemini CLI 代码审查建议:

## 修改内容
- 将 Gemini OAuth 同步中的 `mergeJSONB` 调用替换为 `mergeMap`
- 删除不再使用的 `mergeJSONB` 函数定义

## 原因
- 其他平台(OpenAI、Anthropic)的账户同步都使用 `mergeMap`
- `mergeJSONB` 是为旧的 `model.JSONB` 类型设计,与重构后的架构不一致
- 统一函数命名提高代码可读性和可维护性

## 影响范围
- backend/internal/service/crs_sync_service.go (4处替换)
- backend/internal/service/account.go (删除 mergeJSONB 函数)

## 验证
✓ 编译通过
✓ 功能逻辑无变化(mergeMap 和 mergeJSONB 实现相同)
2025-12-26 22:15:15 +08:00
IanShaw027
9db52838b5 fix(backend): 适配重构后的架构修复 Gemini OAuth 集成
## 主要修改

1. **移除 model 包引用**
   - 删除所有 `internal/model` 包的 import
   - 使用 service 包中的类型定义(Account, Platform常量等)

2. **修复类型转换**
   - JSONB → map[string]any
   - 添加 mergeJSONB 辅助函数
   - 添加 Account.IsGemini() 方法

3. **更新中间件调用**
   - GetUserFromContext → GetAuthSubjectFromContext
   - 适配新的并发控制签名(传递 ID 和 Concurrency 而不是完整对象)

4. **修复 handler 层**
   - 更新 gemini_v1beta_handler.go
   - 修正 billing 检查和 usage 记录

## 影响范围
- backend/internal/service/gemini_*.go
- backend/internal/service/account_test_service.go
- backend/internal/service/crs_sync_service.go
- backend/internal/handler/gemini_v1beta_handler.go
- backend/internal/handler/gateway_handler.go
- backend/internal/handler/admin/account_handler.go
2025-12-26 22:07:55 +08:00
IanShaw027
bfcd9501c2 merge: 合并 upstream/main 解决 PR #37 冲突
- 删除 backend/internal/model/account.go 符合重构方向
- 合并最新的项目结构重构
- 包含 SSE 格式解析修复
- 更新依赖和配置文件
2025-12-26 21:56:08 +08:00
September999999999
12252c6005 fix: 卸载时删除安装锁文件以支持重新安装 (#39)
- 卸载时自动删除 .installed 安装锁文件
- 新增 --purge 参数支持完全清理(包括配置目录)
- 交互模式下增加是否删除配置目录的确认提示
- 支持中英文消息
2025-12-26 21:32:22 +08:00
shaw
2d89f36687 Merge PR #42: fix(sse): 修复非标准 SSE 格式解析问题 2025-12-26 21:31:34 +08:00
shaw
3d608c2625 Merge branch 'refactor/redis-key-helpers' 2025-12-26 21:26:18 +08:00
shaw
739d0ee61e fix: admin handlers 添加 DTO 转换修复 JSON 序列化
修复 PR #36 合并后部分 admin handler 直接返回 service 层对象导致
JSON 字段名为 PascalCase 而非期望的 snake_case 问题。

修复内容:
- account_handler: Refresh 接口添加 dto.AccountFromService
- openai_oauth_handler: RefreshAccountToken/CreateAccountFromOAuth 添加 dto 转换
- subscription_handler: BulkAssign 添加 dto.BulkAssignResultFromService
- usage_handler: List 接口添加 dto.UsageLogFromService 转换
- 新增 dto.BulkAssignResult 类型和对应的 mapper 函数
2025-12-26 21:22:48 +08:00
shaw
22f07a7bb6 Merge PR #36: refactor: 调整项目结构为单向依赖 2025-12-26 20:08:26 +08:00
ianshaw
16eec4eb41 fix(sse): 修复非标准 SSE 格式解析问题
部分上游 API 返回的 SSE 格式不符合标准规范:
- 标准格式: `data: {...}`(冒号后有空格)
- 非标准格式: `data:{...}`(冒号后无空格)

使用预编译正则 `^data:\s*` 统一处理两种格式。
2025-12-26 03:49:55 -08:00
shaw
ecb2c5353c fix: 修复docker-compose.yml redis密码传递问题 2025-12-26 17:25:12 +08:00
Forest
06d5876b02 refactor: 封装 Redis key 生成函数 2025-12-26 16:47:44 +08:00
Forest
e5a77853b0 refactor: 调整项目结构为单向依赖 2025-12-26 16:45:40 +08:00
ianshaw
9780f0fd9d fix(backend): 修复 rebase 后的代码集成问题
- 更新 middleware import 路径到 internal/server/middleware
- 修复 api_key_auth_google.go 使用正确的 service 类型
- 更新 router.go 和 http.go 支持 Gemini v1beta 路由
- 在 routes/gateway.go 中添加 Gemini v1beta API 端点
- 在 routes/admin.go 中添加 Gemini OAuth 路由
- 更新 wire.go 添加 GeminiOAuthService cleanup
- 重新生成 wire_gen.go
2025-12-26 00:17:55 -08:00
ianshaw
3559830882 fix(service): 应用德摩根定律修复 staticcheck QF1001 警告 2025-12-26 00:11:04 -08:00
ianshaw
5594680130 docs(deploy): 说明 AI Studio OAuth Client 需发布为正式版本
README.md:
- 添加第 7 步:发布 OAuth 应用到正式版本
- 说明 Testing 模式限制(100 用户、7 天 token 过期)
- 说明 sensitive scope 可能需要 Google 审核

.env.example:
- 添加 OAuth Client 需发布为正式版本的说明
2025-12-26 00:11:04 -08:00
ianshaw
50855ec15f feat(i18n): 添加 AI Studio OAuth 配置状态相关文案
- 添加 aiStudioNotConfiguredShort/Tip/aiStudioNotConfigured 翻译
- 更新 needsProjectIdDesc/noProjectIdNeededDesc 描述更准确
2025-12-26 00:11:04 -08:00
ianshaw
f9f33e7b5c feat(frontend): UI 显示 AI Studio OAuth 配置状态
CreateAccountModal:
- 检测 AI Studio OAuth 是否可用(服务器配置了自定义 OAuth 客户端)
- 未配置时禁用 AI Studio 选项并显示提示
- 添加悬停提示说明配置要求

ReAuthAccountModal:
- 同步 AI Studio OAuth 可用性检测逻辑
- 未配置时自动回退到 Code Assist
2025-12-26 00:11:03 -08:00
ianshaw
1bec35999b feat(frontend): 添加 Gemini OAuth 能力查询 API
- 添加 GeminiOAuthCapabilities 类型定义
- 添加 getCapabilities API 函数
- useGeminiOAuth composable 导出 getCapabilities 方法
2025-12-26 00:11:03 -08:00
ianshaw
632318ad33 feat(backend): 添加 OAuth 能力查询接口,改进 OAuth 客户端选择逻辑
Handler 改进:
- 添加 GET /api/v1/admin/gemini/oauth/capabilities 接口
- 简化 GenerateAuthURL,redirect_uri 由服务层决定

Repository 改进:
- ExchangeCode/RefreshToken 根据 oauthType 选择正确的 OAuth 客户端
- Code Assist 始终使用内置客户端,AI Studio 使用用户配置的客户端
2025-12-26 00:11:03 -08:00
ianshaw
456e8984b0 feat(service): 改进 Gemini OAuth 服务层,区分 Code Assist 和 AI Studio 客户端
OAuth 服务改进:
- 添加 GetOAuthConfig 返回 AI Studio OAuth 可用性
- Code Assist 强制使用内置 Gemini CLI 客户端
- AI Studio OAuth 要求用户配置自定义 OAuth 客户端
- ExchangeCode/RefreshToken 接口添加 oauthType 参数
- 添加 unauthorized_client 错误的向后兼容重试逻辑

兼容层改进:
- 403 重试逻辑仅对 Code Assist OAuth 生效
- 添加 insufficient-scope 错误检测,避免无效重试
- 上游错误消息脱敏处理(隐藏 API key 等敏感信息)
- 改进错误提示,显示更多上游错误详情
2025-12-26 00:11:03 -08:00
ianshaw
eea949853a feat(geminicli): 添加内置 Gemini CLI OAuth 客户端常量和改进配置逻辑
- 添加 GeminiCLIOAuthClientID/Secret 常量(Gemini CLI 公开 OAuth 客户端)
- 更新 DefaultAIStudioScopes 使用 generative-language.retriever(符合 Google 文档)
- EffectiveOAuthConfig 支持自动回退到内置客户端
- 内置客户端自动过滤受限 scope(如 generative-language)
- 添加 scope 向后兼容性处理
2025-12-26 00:11:03 -08:00
ianshaw
85fd1e4a2c fix(backend): 移除对已删除 ports 包的依赖
适配 main 分支的 ports 目录删除重构:
- 将 ports 包中的接口移至 service 包
- 更新 repository 层的导入路径
2025-12-26 00:11:03 -08:00
ianshaw
6682d06c99 fix(backend): 修复 golangci-lint 报告的格式和代码规范问题
- gofmt: 修复 account_handler.go, models.go, gemini_messages_compat_service.go 的格式
- staticcheck ST1005: 将 error strings 改为小写开头
2025-12-26 00:11:03 -08:00
ianshaw
efa470efc7 fix(backend): 修复 golangci-lint 报告的问题
- gofmt: 修复代码格式问题
- errcheck: 处理 WriteString 和 Close 返回值
- staticcheck: 错误信息改为小写开头
- staticcheck: 移除无效的 nil 检查
- staticcheck: 使用 append 替换循环
- staticcheck: 使用无条件的 TrimPrefix
- ineffassign: 移除无效赋值
- unused: 移除未使用的 geminiOAuthService 字段
- 重新生成 wire_gen.go
2025-12-26 00:11:03 -08:00
ianshaw
79d1585250 docs(deploy): 更新部署配置和文档
- .env.example: 新增 Gemini OAuth 环境变量配置示例
- config.example.yaml: 新增 Gemini OAuth 配置示例
- README.md: 更新部署文档
- docker-compose.yml: 添加 Gemini OAuth 环境变量传递
2025-12-26 00:11:03 -08:00
ianshaw
2d1a15b196 feat(frontend): 添加 Gemini OAuth 类型国际化
- zh.ts: 添加中文翻译(Code Assist/AI Studio 选择等)
- en.ts: 添加英文翻译
2025-12-26 00:11:03 -08:00
ianshaw
09431cfc0b feat(frontend): 支持 Gemini OAuth 类型选择 (Code Assist/AI Studio)
- CreateAccountModal.vue: 新增 OAuth 类型选择 UI
- ReAuthAccountModal.vue: 重授权支持选择类型
- OAuthAuthorizationFlow.vue: 新增 Project ID 输入框
- AccountTestModal.vue: Gemini 模型默认选择优化
- useGeminiOAuth.ts: OAuth 逻辑参数变更
- gemini.ts: API 调用更新
2025-12-26 00:11:03 -08:00
ianshaw
46cb82bac0 feat(backend): 添加 Gemini V1beta Handler 和路由
- 新增 gemini_v1beta_handler.go: 代理原生 Google API 格式
- 更新 gemini_oauth_handler.go: 移除 redirectUri,新增 oauthType
- 更新 account_handler.go: 账户 Handler 增强
- 更新 router.go: 注册 v1beta 路由
- 更新 config.go: Gemini OAuth 通过环境变量配置
- 更新 wire_gen.go: 依赖注入
2025-12-26 00:11:03 -08:00
ianshaw
b2d71da2a2 feat(backend): 实现 Gemini AI Studio OAuth 和消息兼容服务
- gemini_oauth_service.go: 新增 AI Studio OAuth 类型支持
- gemini_token_provider.go: Token 提供器增强
- gemini_messages_compat_service.go: 支持 AI Studio 端点
- account_test_service.go: Gemini 账户可用性检测
- gateway_service.go: 网关服务适配
- openai_gateway_service.go: OpenAI 兼容层调整
2025-12-26 00:11:03 -08:00
ianshaw
2d6e1d26c0 feat(backend): 扩展 Gemini OAuth Repository 层
- 更新 gemini_oauth_client.go: 支持 AI Studio OAuth 客户端
- 更新 geminicli_codeassist_client.go: 适配新的认证流程
2025-12-26 00:11:03 -08:00
ianshaw
50734c5edc feat(backend): 添加 Google API Key 认证中间件
- 新增 api_key_auth_google.go: 支持 x-goog-api-key 格式认证
- 更新 api_key_auth.go: 适配 Gemini 原生 API 格式
2025-12-26 00:11:03 -08:00
ianshaw
040dc27ea5 feat(backend): 添加 Gemini/Google API 基础包
- 新增 pkg/gemini: 模型定义与回退列表
- 新增 pkg/googleapi: Google API 错误状态处理
- 新增 pkg/geminicli/models.go: CLI 模型结构
- 更新 constants.go: AI Studio 相关常量
- 更新 oauth.go: 支持 AI Studio OAuth 流程,凭据通过环境变量配置
2025-12-26 00:10:44 -08:00
ianshaw
d7090de0e0 chore: 更新 .gitignore 忽略 Go 编译缓存
添加 backend/.gocache/ 到忽略列表
2025-12-26 00:10:44 -08:00
ianshaw
cab681c7d1 chore(frontend): 更新项目配置和依赖
- 更新 package-lock.json 依赖版本
- 优化 Vite、PostCSS、Tailwind 配置
- 更新入口 HTML 文件
- 更新 TypeScript 构建缓存
2025-12-26 00:10:44 -08:00
ianshaw
01f990a5c9 style(frontend): 统一核心模块代码风格
- Composables: 优化 OAuth 相关 hooks 代码格式
- Stores: 规范状态管理模块格式
- Types: 统一类型定义格式
- Utils: 优化工具函数格式
- App.vue & style.css: 调整全局样式和主组件格式
2025-12-26 00:10:44 -08:00
ianshaw
5763f5ced3 style(frontend): 统一 Views 模块代码风格
- 移除语句末尾分号,规范代码格式
- 优化组件结构和类型定义
- 改进视图文档和示例
- 提升代码一致性
2025-12-26 00:10:44 -08:00
ianshaw
f79b0f0fad style(frontend): 统一 API 模块代码风格
- 移除所有语句末尾分号
- 统一对象属性尾随逗号格式
- 优化类型定义导入顺序
- 提升代码一致性和可读性
2025-12-26 00:10:44 -08:00
ianshaw
34183b527b feat(frontend): 添加 OAuth 回调路由
- 新增 /auth/callback 路由用于接收 OAuth 授权回调
- 优化路由配置和文档
- 统一代码格式
2025-12-26 00:10:44 -08:00
ianshaw
bceed08fc3 feat(frontend): 添加 Gemini 平台国际化支持
- 新增中文 Gemini OAuth 相关翻译(步骤说明、错误提示等)
- 新增英文 Gemini OAuth 相关翻译
- 添加 Gemini 账号类型、平台名称等基础翻译
- 优化代码格式
2025-12-26 00:10:44 -08:00
ianshaw
5deef27e1d style(frontend): 优化 Components 代码风格和结构
- 统一移除语句末尾分号,规范代码格式
- 优化组件类型定义和 props 声明
- 改进组件文档和示例代码
- 提升代码可读性和一致性
2025-12-26 00:10:01 -08:00
ianshaw
1ac8b1f03e feat(frontend): Components 集成 Gemini 账号支持
- CreateAccountModal: 添加 Gemini 平台选项和 OAuth 授权流程
- EditAccountModal: 支持 Gemini 账号编辑
- OAuthAuthorizationFlow: 新增 Gemini 平台 OAuth 流程处理(支持 state 参数)
- ReAuthAccountModal: 支持 Gemini 账号重新授权
- 优化代码格式和组件逻辑
2025-12-26 00:09:46 -08:00
ianshaw
0b30cc2b7e feat(frontend): 新增 Gemini OAuth 授权流程
- 新增 /admin/gemini API 接口封装(generateAuthUrl, exchangeCode)
- 新增 useGeminiOAuth composable 处理 Gemini OAuth 流程
- 新增 OAuthCallbackView 视图用于接收 OAuth 回调
- 支持 code/state 参数提取和 credentials 构建
2025-12-26 00:09:46 -08:00
ianshaw
03a8ae62e5 feat(backend): 完善 Gemini OAuth Token 处理
- 修复 account_handler 中 token 字段类型转换(int64 转 string)
- 增强 Account.GetCredential 支持多种数值类型(float64, int, json.Number 等)
- 添加 Account.IsGemini() 方法用于平台判断
- 优化 refresh_token 和 scope 的空值处理
2025-12-26 00:09:46 -08:00
ianshaw
e36fb98fb9 feat(handler): 添加 Gemini OAuth Handler 和完善依赖注入
- 新增 Gemini OAuth 授权处理器
- 扩展账号和网关处理器支持 Gemini
- 注册 Gemini 相关路由
- 更新 Wire 依赖注入配置(所有层)
- 更新 Docker Compose 配置
2025-12-26 00:09:46 -08:00
ianshaw
55258bf099 feat(service): 扩展 CRS 同步和定价服务支持 Gemini
- CRS 同步服务新增 Gemini 账号同步逻辑(+273行)
- 定价服务扩展 Gemini 模型定价计算(+99行)
- 更新 Token 刷新服务集成 Gemini
- 更新相关单元测试
2025-12-26 00:09:04 -08:00
ianshaw
dc109827b7 feat(service): 实现 Gemini OAuth 和 Token 管理服务
- 实现 OAuth 授权流程服务
- 添加 Token 提供者和自动刷新机制
- 实现 Gemini Messages API 兼容层
- 更新服务容器注册
2025-12-26 00:09:04 -08:00
ianshaw
71c28e436a feat(service): 定义 Gemini 服务端口接口
- 定义 OAuth 服务接口
- 定义 Token 缓存服务接口
- 定义 Code Assist 服务接口
2025-12-26 00:08:27 -08:00
ianshaw
2bafc28a9b feat(repository): 实现 Gemini OAuth 和 Token 缓存客户端
- 添加 Gemini OAuth 客户端实现
- 实现 Redis 基础的 Token 缓存
- 添加 gemini-cli Code Assist 客户端封装
2025-12-26 00:08:27 -08:00
ianshaw
aea48ae1ab feat(config): 新增 Gemini 配置项和 geminicli 核心包
- 添加 Gemini OAuth 配置结构
- 实现 geminicli 包(OAuth、Token、CodeAssist 类型)
- 更新配置示例文件
2025-12-26 00:08:27 -08:00
shaw
b3463769dc chore: 调整403重试次数跟间隔 2025-12-26 14:19:57 +08:00
shaw
d9e6cfc44d feat: apikey使用弹出适配codex分组 2025-12-26 13:46:40 +08:00
Forest
57fd172287 refactor: 调整 server 目录结构 2025-12-26 10:42:35 +08:00
NepetaLemon
8d7a497553 refactor: 自定义业务错误 (#33)
* refactor: 自定义业务错误

* refactor: 隐藏服务器错误与统一 panic 响应
2025-12-26 08:47:00 +08:00
shaw
b31698b9f2 fix: 修复账户代理ip编辑保存不生效的bug 2025-12-25 21:58:09 +08:00
Forest
eeaff85e47 refactor: 自定义业务错误 2025-12-25 21:06:40 +08:00
Forest
f51ad2e126 refactor: 删除 ports 目录 2025-12-25 17:15:01 +08:00
hi_yueban
f57f12c6cc fix: 修复 OpenAI 账号 5h/7d 使用限制显示错误的问题 (#30)
* fix: 修复 OpenAI 账号 5h/7d 使用限制显示错误的问题

问题描述:
- 账号管理页面中,OpenAI OAuth 账号的 5h 列显示 7 天的剩余时间
- 7d 列却显示几小时的剩余时间
- 根本原因: OpenAI 响应头中 primary/secondary 的实际含义与代码假设相反

修复方案:
1. 后端归一化 (openai_gateway_service.go):
   - 根据 window_minutes 动态判断哪个是 5h/7d 限制
   - 新增规范字段 codex_5h_* 和 codex_7d_*
   - 保留旧字段以兼容性

2. 前端适配 (AccountUsageCell.vue):
   - 优先使用新的规范字段
   - Fallback 到旧字段时基于 window_minutes 动态判断
   - 更新 computed 属性命名

3. 类型定义更新 (types/index.ts):
   - 添加新的规范字段定义
   - 更新注释说明实际语义由 window_minutes 决定

🤖 Generated with Claude Code and Codex collaboration

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-Authored-By: OpenAI Codex <noreply@openai.com>

* fix: 改进窗口判断逻辑,修复两个窗口都小于阈值时的bug

问题:
当两个窗口都小于360分钟时(如 primary=180分钟,secondary=300分钟),
之前的逻辑会导致:
- primary5h = true, secondary5h = true
- 5h 字段会使用 primary(错误)
- 7d 字段没有数据(bug)

修复方案:
改用比较策略:
1. 当两个窗口都存在时:较小的分配给5h,较大的分配给7d
2. 当只有一个窗口时:根据大小(<=360分钟)判断是5h还是7d
3. 确保数据不会丢失,逻辑更健壮

示例:
- Primary: 180分钟, Secondary: 300分钟
  → 5h 使用 Primary(180分钟), 7d 使用 Secondary(300分钟) ✓

🤖 Generated with Claude Code

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix: 修正窗口大小判断逻辑 - 不能用剩余时间判断窗口类型

**严重bug修复:**
之前的 fallback 逻辑错误地使用 reset_after_seconds 来判断窗口大小。

问题示例:
- 周限制(7d)剩余 2h → reset_after_seconds = 7200秒
- 5h限制 剩余 4h → reset_after_seconds = 14400秒
- 错误逻辑:7200/60 < 14400/60,把周限制当成5h 

根本问题:
- window_minutes = 窗口的总大小(300 or 10080)
- reset_after_seconds = 距离重置的剩余时间(变化的)
- 不能用剩余时间来判断窗口类型!

修复方案:
1. **只使用 window_minutes** 来判断窗口大小
2. 移除错误的 reset_after_seconds fallback
3. 如果 window_minutes 都不存在,使用传统假设
4. 添加详细注释说明这个陷阱

🤖 Generated with Claude Code

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix: 修复 lint 问题 - 改进 fallback 逻辑的变量赋值

问题:
第882-883行的简单布尔赋值可能触发 ineffassign 或 staticcheck 警告:
  use5hFromSecondary = snapshot.SecondaryUsedPercent != nil
  use7dFromPrimary = snapshot.PrimaryUsedPercent != nil

修复:
改用明确的 if 语句检查任意字段是否存在,更符合代码意图:
- 如果 secondary 的任意字段存在,将其视为 5h
- 如果 primary 的任意字段存在,将其视为 7d

这样逻辑更清晰,也避免了 lint 警告。

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: OpenAI Codex <noreply@openai.com>
2025-12-25 17:00:02 +08:00
shaw
5fca2d10b9 fix: 修复image地址 2025-12-25 16:57:29 +08:00
shaw
8fbe1ad70d chore: 调整403重试次数跟间隔 2025-12-25 16:23:31 +08:00
Forest
25a304c231 test: 增加 repository 测试 2025-12-25 16:01:17 +08:00
刀刀
9d30ceae8d CC 400 返回具体错误信息 && 非 CC 请求时增加 system prompt (#26)
* feat: http 400 返回具体错误

* 更新 workflows

* 优化打包/docker 构建流程

* 400 是返回 原始错误 - json 格式

* feat: 非 cc请求时补充 system

* go mod tidy
2025-12-25 14:47:19 +08:00
IanShaw
60f6ed6bf6 feat: CRS 同步增强 - 自动刷新 OAuth token 和修复测试配置 (#27)
* fix(service): 修复 OpenAI Responses API 测试负载配置

- 所有账号类型统一添加 instructions 字段(不再仅限 OAuth)
- Responses API 要求所有请求必须包含 instructions 参数

* feat(crs-sync): CRS 同步时自动刷新 OAuth token 并保留完整 extra 字段

**核心功能**:
- CRSSyncService 注入 OAuth 服务依赖(Anthropic + OpenAI)
- 账号创建/更新后自动刷新 OAuth token,确保可用性
- 完整保留 CRS extra 字段,避免数据丢失

**Extra 字段增强**:
- 保留 CRS 所有原始 extra 字段
- 新增同步元数据: crs_account_id, crs_kind, crs_synced_at
- Claude 账号: 从 credentials 提取 org_uuid/account_uuid 到 extra
- OpenAI 账号: 映射 crs_email -> email

**Token 刷新逻辑**:
- 新增 refreshOAuthToken() 方法处理 Anthropic/OpenAI 平台
- 保留原有 credentials 字段,仅更新 token 相关字段
- 刷新失败静默处理,不中断同步流程

**依赖注入**:
- wire_gen.go: CRSSyncService 新增 oAuthService/openaiOAuthService

* style(crs-sync): 使用 switch 替代 if-else 修复 golangci-lint 警告

- 将 refreshOAuthToken 中的 if-else 改为 switch 语句
- 符合 staticcheck 规范
- 添加 default 分支处理未知平台
2025-12-25 14:45:17 +08:00
shaw
4a2f7d4a99 chore: CRS迁移功能增加版本提示 2025-12-25 10:57:04 +08:00
shaw
c19a393be9 Merge PR #24: feat: 添加账户同步与批量编辑功能
- 添加从 CRS 同步账户功能 (Claude OAuth/API Key, OpenAI OAuth/Responses)
- 添加批量编辑账户功能,支持 JSONB 字段智能合并
- 新增 CRSSyncService、BulkUpdate 仓储方法
- 前端新增 SyncFromCrsModal 和 BulkEditAccountModal 组件
2025-12-25 10:44:40 +08:00
ianshaw
938ffb002e style(frontend): format code with prettier
格式化前端业务代码,符合代码规范
- 统一代码风格
- 修复 ESLint 警告
2025-12-24 18:07:58 -08:00
ianshaw
372a01290b fix(backend): handle defer Close() errors in crs_sync_service
修复 golangci-lint 错误检查问题
- 使用匿名函数包装 defer Close() 并忽略错误
- 符合 Go 最佳实践
2025-12-24 17:58:47 -08:00
ianshaw
8b163ca49b chore: trigger CI after enabling Actions 2025-12-24 17:56:55 -08:00
ianshaw
d23810dc53 chore: trigger CI workflow 2025-12-24 17:54:43 -08:00
ianshaw
62ed5422dd feat(account): 优化批量更新实现,使用统一 SQL 合并 JSONB 字段
- 新增 BulkUpdate 仓储方法,使用单条 SQL 更新所有账户
- credentials/extra 使用 COALESCE(...) || ? 合并,只更新传入的 key
- name/proxy_id/concurrency/priority/status 只在提供时更新
- 分组绑定仍逐账号处理(需要独立操作)
- 前端优化:Base URL 留空则不修改,按勾选字段更新
- 完善 i18n 文案:说明留空不修改、批量更新行为
2025-12-24 17:16:19 -08:00
ianshaw
2e76302af7 feat(account): 添加批量编辑账户凭据功能并优化 CRS 同步
- 新增批量更新账户凭据接口(account_uuid/org_uuid/intercept_warmup_requests)
- 新增前端批量编辑模态框组件
- 优化 CRS 同步逻辑,改进 extra 字段处理
- 优化 CRS 同步 UI,添加更详细的结果展示
- 完善国际化文案(中英文)
2025-12-24 16:56:48 -08:00
ianshaw
6553828008 feat(account): 添加从 CRS 同步账户功能
- 添加账户同步 API 接口 (account_handler.go)
- 实现 CRS 同步服务 (crs_sync_service.go)
- 添加前端同步对话框组件 (SyncFromCrsModal.vue)
- 更新账户管理界面支持同步操作
- 添加账户仓库批量创建方法
- 添加中英文国际化翻译
- 更新依赖注入配置
2025-12-24 08:48:58 -08:00
ianshaw
adcb7bf00e chore: 更新 .gitignore 忽略配置文件并还原 Makefile
- 添加 backend/config.yaml 到 .gitignore(包含敏感信息)
- 添加 deploy/config.yaml 到 .gitignore(包含敏感信息)
- 添加 backend/.installed 到 .gitignore
- 还原 Makefile 到原始版本
2025-12-24 08:48:49 -08:00
shaw
876e85e7ad Merge branch 'feat/rename-go-module' 2025-12-24 21:34:37 +08:00
shaw
2e7818d688 feat(settings): 添加文档链接配置功能
- 后台系统设置新增文档链接(doc_url)配置项
- 首页顶部导航栏显示文档链接图标(条件渲染)
- Footer区域添加文档链接和GitHub链接
- 支持中英文国际化
2025-12-24 21:30:19 +08:00
Forest
836c4dda2b refactor: 重命名 go module 2025-12-24 21:07:21 +08:00
shaw
e65e9587b4 fix(concurrency): 重构并发管理使用独立Key+原生TTL
问题:旧方案使用计数器模式,每次acquire都刷新TTL,导致僵尸数据永不过期

解决方案:
- 每个槽位使用独立Redis Key: concurrency:account:{id}:{requestID}
- 利用Redis原生TTL,每个槽位独立5分钟过期
- 服务崩溃后僵尸数据自动清理,无需手动干预
- 兼容多实例K8s部署

技术改动:
- 新增SCAN脚本统计活跃槽位数量
- 移除冗余的releaseScript,直接使用DEL命令
- Wait队列TTL只在首次创建时设置,避免刷新
2025-12-24 21:00:29 +08:00
shaw
aaadd6ed04 fix(dashboard): 修复性能指标 RPM/TPM 显示为0的问题
- 修复 Admin Dashboard Handler 遗漏返回 rpm/tpm 字段
- 将性能统计时间窗口从1分钟改为5分钟平均值,数据更稳定
2025-12-24 19:58:33 +08:00
shaw
870b21916c feat(install): 添加安装指定版本和回退功能
- 新增 rollback 命令支持回退到指定版本
- 新增 list-versions 命令列出可用版本
- 新增 -v/--version 参数指定安装版本
- upgrade 命令支持升级到指定版本
- 添加安装状态检查,未安装时给出明确提示
- 版本切换仅替换二进制文件,保留配置和数据
- 自动备份当前版本(带版本号或时间戳后缀)
- 改进网络错误处理,添加超时和友好提示
- 修复 grep -oP 兼容性问题,改用 grep -oE
2025-12-24 17:44:13 +08:00
shaw
fb119f9a67 fix(version): 优化服务重启后页面刷新时机
- 将重启后等待时间从 3 秒增加到 8 秒
- 添加倒计时显示,提升用户体验
- 倒计时结束后先检测服务健康状态再刷新页面
- 避免刷新过早导致 502 错误
2025-12-24 17:21:17 +08:00
shaw
ad54795a24 feat(gateway): 添加上游错误重试机制
- OAuth/Setup Token 账号遇到 403 错误时,等待 2 秒后重试,最多 3 次
- Console 账号遇到未配置的错误码时,同样进行重试
- 重试耗尽后:OAuth 403 标记账号异常,Console 未配置错误码不标记账号
- 移除 handleErrorResponse 中已被重试逻辑覆盖的死代码
2025-12-24 16:55:46 +08:00
shaw
0abe322cca feat(accounts): 账户列表显示实时并发数
- 在账户列表 API 返回中添加 current_concurrency 字段
- 合并平台和类型列为 PlatformTypeBadge 组件,节省表格空间
- 新增并发状态列,显示 当前/最大 并发数,支持颜色编码
2025-12-24 15:44:45 +08:00
shaw
b071511676 refactor(accounts): 优化用量窗口显示,统一 OAuth 和 Setup Token 处理
- Setup Token 账号现在也调用 API 获取 5h 窗口用量数据
- 重新设计 UsageProgressBar UI,将用量统计移到进度条上方
- 删除冗余的 SetupTokenTimeWindow 组件
- 请求数/Token数支持 K/M/B 单位显示
2025-12-24 10:57:40 +08:00
shaw
7d9a757a26 feat(dashboard): 添加 RPM/TPM 性能指标
在 Dashboard 中用 RPM/TPM 卡片替换原来的"今日缓存"卡片,
实时显示最近1分钟的请求数和 Token 吞吐量。
2025-12-24 10:24:02 +08:00
Forest
bbf4024dc7 refactor(usage): 移动 usage 查询到 services 2025-12-24 08:41:31 +08:00
shaw
5831eb8a6a fix: 修复Claude OAuth token交换时authorization code解析错误
原代码中 `parts` 变量被创建但从未使用,导致 `len(parts) == 0`
永远为 true,使得即使成功从 `code#state` 格式中分割出 authCode,
最后也会被覆盖为原始的完整字符串。

这导致传递给Claude Token端点的code包含了 `#state` 部分,
Claude返回 "Invalid 'code' in request" 错误。
2025-12-23 19:42:52 +08:00
shaw
61838cdb3d fix: 兼容GLM等API的usage数据解析
部分第三方API(如GLM)的SSE响应格式与标准Claude API不同:
- 标准Claude: input_tokens在message_start中
- GLM等API: 所有tokens都在message_delta中

现在从message_delta中也解析input_tokens和cache相关字段,
如果message_start中没有值则使用message_delta中的数据。
2025-12-23 19:42:52 +08:00
dexcoder6
50dba656fd feat: 添加用户余额充值/退款功能 (#17)
## 功能特性

### 前端
- 在用户列表操作列添加充值和退款按钮
- 实现充值/退款对话框,支持输入金额和备注
- 从编辑用户表单中移除余额字段,防止直接修改
- 添加余额不足验证,实时显示操作后余额
- 优化备注提示词,提供多种场景示例

### 后端
- 为 redeem_codes 表添加 notes 字段(迁移文件)
- 在 UpdateUserBalance 接口添加 notes 参数支持
- 添加余额验证:金额必须大于0,操作后余额不能为负
- UpdateUser 接口移除 balance 字段处理,防止误操作
- 完整的审计日志和缓存管理

## 安全保护

- 前端:余额不足时禁用提交按钮,实时提示
- 后端:双重验证(输入金额 > 0 + 结果余额 >= 0)
- 权限:仅管理员可访问(AdminAuth 中间件)
- 审计:所有操作记录到 redeem_codes 表

## 修改文件

后端:
- backend/migrations/004_add_redeem_code_notes.sql
- backend/internal/model/redeem_code.go
- backend/internal/service/admin_service.go
- backend/internal/handler/admin/user_handler.go

前端:
- frontend/src/views/admin/UsersView.vue
- frontend/src/api/admin/users.ts
- frontend/src/i18n/locales/zh.ts
- frontend/src/i18n/locales/en.ts

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-23 16:29:57 +08:00
shaw
0e2821456c chore: 忽略TypeScript增量编译缓存文件 2025-12-23 16:27:56 +08:00
shaw
f25ac3aff5 feat: OpenAI OAuth账号显示Codex使用量
从响应头提取x-codex-*使用量信息并保存到账号Extra字段,
前端账号列表展示5h/7d窗口的使用进度条。
2025-12-23 16:26:07 +08:00
shaw
f6341b7f2b chore: 将"代理管理"菜单更名为"IP管理" 2025-12-23 15:46:10 +08:00
shaw
4e257512b9 style: 统一平台和分组列的样式
- 账号页面平台列改为与分组页面一致的标签样式
- 订阅页面分组列改用 GroupBadge 组件展示
- 修正 OpenAI OAuth 类型描述文案
2025-12-23 15:40:22 +08:00
shaw
e53b34f321 Merge PR #15: feat: 增强用户管理功能,添加用户名、微信号和备注字段 2025-12-23 14:03:07 +08:00
shaw
12ddae0184 fix: 优化OpenAI模型定价查找的回退逻辑
当模型ID在model_pricing.json中找不到时,增加智能回退策略:
- gpt-5.2-codex → 回退到 gpt-5.2
- gpt-5.2-20251222 → 去掉日期后缀回退到 gpt-5.2
- 最终回退到 DefaultTestModel (gpt-5.1-codex)
2025-12-23 13:58:56 +08:00
shaw
7b9c3f165e feat: 账号管理新增使用统计功能
- 新增账号统计弹窗,展示30天使用数据
- 显示总费用、请求数、日均费用、日均请求等汇总指标
- 显示今日概览、最高费用日、最高请求日
- 包含费用与请求趋势图(双Y轴)
- 复用模型分布图组件展示模型使用分布
- 显示实际扣费和标准计费(标准计费以较淡颜色显示)
2025-12-23 13:42:33 +08:00
dexcoder6
0b8e84f942 feat: 增强用户管理功能,添加用户名、微信号和备注字段
- 新增User模型字段:username(用户名)、wechat(微信号)、notes(备注)
- 扩展用户搜索功能,支持通过用户名和微信号搜索
- 添加用户个人资料更新功能,用户可自行编辑用户名和微信号
- 管理员用户列表新增用户名、微信号、备注显示列
- 备注字段仅对管理员可见,增强数据安全性
- 完善中英文国际化翻译
- 修复国际化文件中重复属性的TypeScript错误

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-23 11:26:22 +08:00
shaw
d9e27df9af feat: 账号列表显示所属分组
- Account模型新增Groups虚拟字段
- 账号列表API预加载Group信息
- 账号管理页面新增分组列,使用GroupBadge展示
2025-12-23 11:20:02 +08:00
shaw
f0fabf89a1 feat: 用户列表显示订阅分组及剩余天数
- User模型新增Subscriptions关联
- 用户列表批量加载订阅信息避免N+1查询
- GroupBadge组件支持显示剩余天数(过期红色、<=3天红色、<=7天橙色)
- 用户管理页面新增订阅分组列
2025-12-23 11:03:10 +08:00
shaw
5bbfbcdae9 fix: 修复订阅窗口过期后进度条显示不正确的问题
问题:滑动窗口过期后(如昨天用满额度),前端仍显示历史数据(红色进度条100%、"即将重置")

解决:
- 后端返回数据前检查窗口是否过期,过期则清零展示数据
- 前端处理 window_start 为 null 的情况,显示"窗口未激活"
- 不影响实际的窗口激活逻辑,窗口仍从当天零点开始
2025-12-23 10:38:15 +08:00
shaw
eb55947ec4 fix: 修复golangci-lint检查问题
- 移除OpenAIGatewayHandler中未使用的userService字段
- 将账号类型判断的if-else链改为switch语句
2025-12-23 10:25:32 +08:00
shaw
5f7e5184eb feat: admin/subscriptions新增重置时间显示 2025-12-23 10:14:41 +08:00
shaw
008a111268 chore: 更新前端构建信息 2025-12-23 10:03:34 +08:00
shaw
fda753278c feat: 平台图标与计费修复
- fix(billing): 修复 OpenAI 兼容 API 缓存 token 重复计费问题
- fix(auth): 隐藏数据库错误详情,返回通用服务不可用错误
- feat(ui): 新增 PlatformIcon 组件,GroupBadge 支持平台颜色区分
- feat(ui): 账号管理新增重置状态按钮,重授权后自动清除错误
- feat(ui): 分组管理新增计费类型列,显示订阅限额信息
- ui: 首页 GPT 状态改为已支持
2025-12-23 10:01:58 +08:00
shaw
6c469b42ed feat: 新增支持codex转发 2025-12-22 22:58:31 +08:00
shaw
dacf3a2a6e fix: 去掉accept-encoding透传 2025-12-21 21:30:19 +08:00
shaw
e6add93ae3 fix(build): add -tags embed to ensure frontend is embedded
- Add -tags=embed flag to GoReleaser builds
- Add -tags embed flag to Dockerfile builds
- Fix Dockerfile COPY order to prevent frontend dist being overwritten
- Update README build instructions with embed tag explanation
2025-12-20 19:13:26 +08:00
NepetaLemon
b2273ec695 ci(backend): 修复 backend-ci 2025-12-20 16:52:38 +08:00
Forest
aa89777dda ci(backend): 调整 embed server 2025-12-20 16:44:25 +08:00
Forest
1e1f3c0c74 ci(backend): 添加 gofmt 配置 2025-12-20 16:19:40 +08:00
Forest
1fab9204eb ci(backend): 添加 unused 配置 2025-12-20 16:12:44 +08:00
Forest
dbd3e71637 ci(backend): 添加 staticcheck 配置 2025-12-20 16:01:24 +08:00
Forest
974f67211b ci(backend): 添加 ineffassign 配置 2025-12-20 15:58:08 +08:00
Forest
0338c83b90 ci(backend): 添加 errcheck 配置 2025-12-20 15:52:13 +08:00
NepetaLemon
c6b3de1199 ci(backend): 添加 github actions (#10)
## 变更内容

### CI/CD
- 添加 GitHub Actions 工作流(test + golangci-lint)
- 添加 golangci-lint 配置,启用 errcheck/govet/staticcheck/unused/depguard
- 通过 depguard 强制 service 层不能直接导入 repository

### 错误处理修复
- 修复 CSV 写入、SSE 流式输出、随机数生成等未处理的错误
- GenerateRedeemCode() 现在返回 error

### 资源泄露修复
- 统一使用 defer func() { _ = xxx.Close() }() 模式

### 代码清理
- 移除未使用的常量
- 简化 nil map 检查
- 统一代码格式
2025-12-20 02:29:52 -05:00
shaw
f1325e9ae6 chore: 调整Turnstile设置跳转地址 2025-12-20 15:14:36 +08:00
shaw
587012396b feat: 支持创建管理员APIKEY 2025-12-20 15:11:43 +08:00
shaw
adebd941e1 fix: 修复Oauth账号自动刷新token失败的bug 2025-12-20 13:01:58 +08:00
Wesley Liddick
bb500b7b2a Merge pull request #9 from NepetaLemon/refactor/add-http-service-ports
refactor(backend): service http ports
2025-12-19 23:35:13 -05:00
Forest
cceada7dae refactor(backend): service http ports 2025-12-20 11:57:02 +08:00
shaw
5c2e7ae265 fix: 调整订阅计费时间窗口为每日0点
- 窗口激活/重置时使用当天零点而非精确时间
- 使用服务器本地时区计算零点(支持 UTC+8 等时区)
- 窗口重置时失效 Redis 缓存,避免数据不一致
2025-12-20 11:33:06 +08:00
shaw
420bedd615 Merge PR #8: refactor(backend): 添加 service 缓存端口 2025-12-20 11:05:01 +08:00
shaw
a79f6c5e1e feat: 给所有表格页面增加刷新按钮 2025-12-20 10:48:42 +08:00
shaw
0484c59ead feat: /admin/usage页面增加模型分布情况显示 2025-12-20 10:06:55 +08:00
Forest
7bbf621490 refactor(backend): 添加 service 缓存端口 2025-12-19 23:44:18 +08:00
shaw
ef81aeb463 fix: 修复dashboard页面用户名的显示bug 2025-12-19 22:41:26 +08:00
shaw
22414326cc fix: 修复前端切换页面时logo跟标题闪烁的问题 2025-12-19 22:33:36 +08:00
Wesley Liddick
14b155c66b Merge pull request #7 from NepetaLemon/refactor/ports-pattern
refactor(backend): 引入端口接口模式
2025-12-19 08:29:04 -05:00
Forest
e99b344b2b refactor(backend): 引入端口接口模式 2025-12-19 21:26:19 +08:00
shaw
7fd94ab78b fix: 修复usage页面未显示缓存写入的问题 2025-12-19 16:57:31 +08:00
shaw
078529e51e chore: 更新docker的postgres版本为18 2025-12-19 16:42:03 +08:00
shaw
23a4cf11c8 fix: 设置默认logo作为favicon 2025-12-19 16:41:00 +08:00
shaw
d1f0902ec0 feat(account): 支持账号级别拦截预热请求
- 新增 intercept_warmup_requests 配置项,存储在 credentials 字段
- 启用后,标题生成、Warmup 等预热请求返回 mock 响应,不消耗上游 token
- 前端支持所有账号类型(OAuth、Setup Token、API Key)的开关配置
- 修复 OAuth 凭证刷新时丢失非 token 配置的问题
2025-12-19 16:39:25 +08:00
shaw
ee86dbca9d feat(account): 账号测试支持选择模型
- 新增 GET /api/v1/admin/accounts/:id/models 接口获取账号可用模型
- 账号测试弹窗新增模型选择下拉框
- 测试时支持传入 model_id 参数,不传则默认使用 Sonnet
- API Key 账号支持根据 model_mapping 映射测试模型
- 将模型常量提取到 claude 包统一管理
2025-12-19 16:00:09 +08:00
Wesley Liddick
733d4c2b85 Merge pull request #6 from dexcoder6/main
fix(frontend): 修复移动端菜单栏和使用记录页面 UI 问题
2025-12-19 02:59:05 -05:00
dexcoder6
406d3f3cab fix(frontend): 修复移动端菜单栏和使用记录页面 UI 问题
- 修复移动端无法打开菜单栏的问题
  - 在 app.ts 中添加 mobileOpen 状态管理
  - 修复 AppHeader.vue 中移动端菜单按钮调用错误的方法
  - 修复 AppSidebar.vue 使用本地 ref 而非全局状态的问题

- 添加移动端菜单自动关闭功能
  - 点击菜单项后自动关闭侧边栏
  - 添加 150ms 延迟以显示关闭动画

- 修复使用记录页面总消费卡片溢出问题
  - 调整总消费卡片布局,将删除线价格移至说明行
  - 添加 min-w-0 flex-1 防止内容溢出
  - 保持与其他卡片高度一致

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-19 15:55:42 +08:00
shaw
1ed93a5fd0 refactor: 提取 Claude 客户端常量到独立包
- 新增 internal/pkg/claude 包统一管理 Claude Code 相关常量
  - 统一账号测试逻辑,所有账号类型使用相同的 Claude Code 风格请求
  - 网关服务使用常量包替换硬编码的 beta header 字符串
2025-12-19 15:22:52 +08:00
shaw
463ddea36f fix(frontend): 修复代理快捷添加弹窗的 i18n 解析错误
batchInputHint 中的 @ 符号需要使用 {'@'} 转义
2025-12-19 11:24:22 +08:00
shaw
e769f67699 fix(setup): 支持从配置文件读取 Setup Wizard 监听地址
Setup Wizard 之前硬编码使用 8080 端口,现在支持从 config.yaml 或
环境变量 (SERVER_HOST, SERVER_PORT) 读取监听地址,方便用户在端口
被占用时使用其他地址启动初始化向导。
2025-12-19 11:21:58 +08:00
shaw
52d2ae9708 feat(gateway): 添加 /v1/messages/count_tokens 端点
实现 Claude API 的 token 计数功能,支持 OAuth、SetupToken 和 ApiKey 三种账号类型。

特点:
- 校验订阅/余额(不扣费)
- 不计算用户和账号并发
- 不记录使用量
- 支持模型映射(ApiKey 账号)
- 支持 OAuth 账号的指纹管理和 401 重试
2025-12-19 11:12:41 +08:00
shaw
2e59998c51 fix: 代理表单字段保存时自动去除前后空格
前后端同时处理,防止因意外空格导致代理连接失败
2025-12-19 10:39:30 +08:00
shaw
32e58115cc fix(frontend): 修复代理快捷添加弹窗的 i18n 解析错误
转义 batchInputPlaceholder 中的 @ 符号,防止 Vue I18n 将其误解析为链接消息语法
2025-12-19 10:32:22 +08:00
shaw
ba27026399 docs: 调整源码编译步骤的顺序 2025-12-19 09:47:17 +08:00
350 changed files with 48082 additions and 11883 deletions

41
.github/workflows/backend-ci.yml vendored Normal file
View File

@@ -0,0 +1,41 @@
name: CI
on:
push:
pull_request:
permissions:
contents: read
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: backend/go.mod
check-latest: true
cache: true
- name: Unit tests
working-directory: backend
run: make test-unit
- name: Integration tests
working-directory: backend
run: make test-integration
golangci-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: backend/go.mod
check-latest: true
cache: true
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:
version: v2.7
args: --timeout=5m
working-directory: backend

View File

@@ -85,6 +85,19 @@ jobs:
go-version: '1.24'
cache-dependency-path: backend/go.sum
# Docker setup for GoReleaser
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Fetch tags with annotations
run: |
# 确保获取完整的 annotated tag 信息
@@ -117,87 +130,16 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
GITHUB_REPO_OWNER: ${{ github.repository_owner }}
GITHUB_REPO_NAME: ${{ github.event.repository.name }}
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
# ===========================================================================
# Docker Build and Push
# ===========================================================================
docker:
needs: [update-version, build-frontend]
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Download VERSION artifact
uses: actions/download-artifact@v4
with:
name: version-file
path: backend/cmd/server/
- name: Download frontend artifact
uses: actions/download-artifact@v4
with:
name: frontend-dist
path: backend/internal/web/dist/
# Extract version from tag
- name: Extract version
id: version
run: |
VERSION=${GITHUB_REF#refs/tags/v}
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Version: $VERSION"
# Set up Docker Buildx for multi-platform builds
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
# Login to DockerHub
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Extract metadata for Docker
- name: Extract Docker metadata
id: meta
uses: docker/metadata-action@v5
with:
images: |
weishaw/sub2api
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=raw,value=latest,enable={{is_default_branch}}
# Build and push Docker image
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
COMMIT=${{ github.sha }}
DATE=${{ github.event.head_commit.timestamp }}
cache-from: type=gha
cache-to: type=gha,mode=max
# Update DockerHub description (optional)
# Update DockerHub description
- name: Update DockerHub description
uses: peter-evans/dockerhub-description@v4
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
repository: weishaw/sub2api
repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
short-description: "Sub2API - AI API Gateway Platform"
readme-filepath: ./deploy/DOCKER.md

18
.gitignore vendored
View File

@@ -21,6 +21,9 @@ coverage.html
# 依赖(使用 go mod
vendor/
# Go 编译缓存
backend/.gocache/
# ===================
# Node.js / Vue 前端
# ===================
@@ -28,6 +31,7 @@ node_modules/
frontend/node_modules/
frontend/dist/
*.local
*.tsbuildinfo
# 日志
npm-debug.log*
@@ -81,15 +85,27 @@ build/
release/
# 后端嵌入的前端构建产物
# Keep a placeholder file so `//go:embed all:dist` always has a match in CI/lint,
# while still ignoring generated frontend build outputs.
backend/internal/web/dist/
!backend/internal/web/dist/
backend/internal/web/dist/*
!backend/internal/web/dist/.keep
# 后端运行时缓存数据
backend/data/
# ===================
# 本地配置文件(包含敏感信息)
# ===================
backend/config.yaml
deploy/config.yaml
backend/.installed
# ===================
# 其他
# ===================
tests
CLAUDE.md
.claude
scripts
scripts

View File

@@ -11,6 +11,8 @@ builds:
dir: backend
main: ./cmd/server
binary: sub2api
flags:
- -tags=embed
env:
- CGO_ENABLED=0
goos:
@@ -50,10 +52,58 @@ changelog:
# 禁用自动 changelog完全使用 tag 消息
disable: true
# Docker images
dockers:
- id: amd64
goos: linux
goarch: amd64
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
- id: arm64
goos: linux
goarch: arm64
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
# Docker manifests for multi-arch support
docker_manifests:
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:latest"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}.{{ .Minor }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
release:
github:
owner: Wei-Shaw
name: sub2api
owner: "{{ .Env.GITHUB_REPO_OWNER }}"
name: "{{ .Env.GITHUB_REPO_NAME }}"
draft: false
prerelease: auto
name_template: "Sub2API {{.Version}}"
@@ -71,7 +121,7 @@ release:
**One-line install (Linux):**
```bash
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
curl -sSL https://raw.githubusercontent.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/main/deploy/install.sh | sudo bash
```
**Manual download:**
@@ -79,5 +129,5 @@ release:
## 📚 Documentation
- [GitHub Repository](https://github.com/Wei-Shaw/sub2api)
- [Installation Guide](https://github.com/Wei-Shaw/sub2api/blob/main/deploy/README.md)
- [GitHub Repository](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }})
- [Installation Guide](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/blob/main/deploy/README.md)

View File

@@ -40,14 +40,15 @@ WORKDIR /app/backend
COPY backend/go.mod backend/go.sum ./
RUN go mod download
# Copy frontend dist from previous stage
COPY --from=frontend-builder /app/frontend/../backend/internal/web/dist ./internal/web/dist
# Copy backend source
# Copy backend source first
COPY backend/ ./
# Build the binary (BuildType=release for CI builds)
# Copy frontend dist from previous stage (must be after backend copy to avoid being overwritten)
COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
# Build the binary (BuildType=release for CI builds, embed frontend)
RUN CGO_ENABLED=0 GOOS=linux go build \
-tags embed \
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
-o /app/sub2api \
./cmd/server

40
Dockerfile.goreleaser Normal file
View File

@@ -0,0 +1,40 @@
# =============================================================================
# Sub2API Dockerfile for GoReleaser
# =============================================================================
# This Dockerfile is used by GoReleaser to build Docker images.
# It only packages the pre-built binary, no compilation needed.
# =============================================================================
FROM alpine:3.19
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
LABEL description="Sub2API - AI API Gateway Platform"
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
# Install runtime dependencies
RUN apk add --no-cache \
ca-certificates \
tzdata \
curl \
&& rm -rf /var/cache/apk/*
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
WORKDIR /app
# Copy pre-built binary from GoReleaser
COPY sub2api /app/sub2api
# Create data directory
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
USER sub2api
EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
ENTRYPOINT ["/app/sub2api"]

View File

@@ -216,26 +216,25 @@ Build and run from source code for development or customization.
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 2. Build backend
cd backend
go build -o sub2api ./cmd/server
# 3. Build frontend
cd ../frontend
# 2. Build frontend
cd frontend
npm install
npm run build
# Output will be in ../backend/internal/web/dist/
# 4. Copy frontend build to backend (for embedding)
cp -r dist ../backend/internal/web/
# 5. Create configuration file
# 3. Build backend with embedded frontend
cd ../backend
go build -tags embed -o sub2api ./cmd/server
# 4. Create configuration file
cp ../deploy/config.example.yaml ./config.yaml
# 6. Edit configuration
# 5. Edit configuration
nano config.yaml
```
> **Note:** The `-tags embed` flag embeds the frontend into the binary. Without this flag, the binary will not serve the frontend UI.
**Key configuration in `config.yaml`:**
```yaml
@@ -266,7 +265,7 @@ default:
```
```bash
# 7. Run the application
# 6. Run the application
./sub2api
```

View File

@@ -216,26 +216,25 @@ docker-compose logs -f
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 2. 编译
cd backend
go build -o sub2api ./cmd/server
# 3. 编译前端
cd ../frontend
# 2. 编译
cd frontend
npm install
npm run build
# 构建产物输出到 ../backend/internal/web/dist/
# 4. 复制前端构建产物到后端(用于嵌入
cp -r dist ../backend/internal/web/
# 5. 创建配置文件
# 3. 编译后端(嵌入前端
cd ../backend
go build -tags embed -o sub2api ./cmd/server
# 4. 创建配置文件
cp ../deploy/config.example.yaml ./config.yaml
# 6. 编辑配置
# 5. 编辑配置
nano config.yaml
```
> **注意:** `-tags embed` 参数会将前端嵌入到二进制文件中。不使用此参数编译的程序将不包含前端界面。
**`config.yaml` 关键配置:**
```yaml
@@ -266,7 +265,7 @@ default:
```
```bash
# 7. 运行应用
# 6. 运行应用
./sub2api
```

602
backend/.golangci.yml Normal file
View File

@@ -0,0 +1,602 @@
version: "2"
linters:
default: none
enable:
- depguard
- errcheck
- govet
- ineffassign
- staticcheck
- unused
settings:
depguard:
rules:
# Enforce: service must not depend on repository.
service-no-repository:
list-mode: original
files:
- "**/internal/service/**"
deny:
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
desc: "service must not import repository"
- pkg: gorm.io/gorm
desc: "service must not import gorm"
- pkg: github.com/redis/go-redis/v9
desc: "service must not import redis"
handler-no-repository:
list-mode: original
files:
- "**/internal/handler/**"
deny:
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
desc: "handler must not import repository"
- pkg: gorm.io/gorm
desc: "handler must not import gorm"
- pkg: github.com/redis/go-redis/v9
desc: "handler must not import redis"
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: true
# report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`.
# Such cases aren't reported by default.
# Default: false
check-blank: false
# To disable the errcheck built-in exclude list.
# See `-excludeonly` option in https://github.com/kisielk/errcheck#excluding-functions for details.
# Default: false
disable-default-exclusions: true
# List of functions to exclude from checking, where each entry is a single function to exclude.
# See https://github.com/kisielk/errcheck#excluding-functions for details.
exclude-functions:
- io/ioutil.ReadFile
- io.Copy(*bytes.Buffer)
- io.Copy(os.Stdout)
- fmt.Println
- fmt.Print
- fmt.Printf
- fmt.Fprint
- fmt.Fprintf
- fmt.Fprintln
# Display function signature instead of selector.
# Default: false
verbose: true
ineffassign:
# Check escaping variables of type error, may cause false positives.
# Default: false
check-escaping-errors: true
staticcheck:
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
dot-import-whitelist:
- fmt
# https://staticcheck.dev/docs/configuration/options/#initialisms
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
# Default: ["200", "400", "404", "500"]
http-status-code-whitelist: [ "200", "400", "404", "500" ]
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
checks:
# Invalid regular expression.
# https://staticcheck.dev/docs/checks/#SA1000
- SA1000
# Invalid template.
# https://staticcheck.dev/docs/checks/#SA1001
- SA1001
# Invalid format in 'time.Parse'.
# https://staticcheck.dev/docs/checks/#SA1002
- SA1002
# Unsupported argument to functions in 'encoding/binary'.
# https://staticcheck.dev/docs/checks/#SA1003
- SA1003
# Suspiciously small untyped constant in 'time.Sleep'.
# https://staticcheck.dev/docs/checks/#SA1004
- SA1004
# Invalid first argument to 'exec.Command'.
# https://staticcheck.dev/docs/checks/#SA1005
- SA1005
# 'Printf' with dynamic first argument and no further arguments.
# https://staticcheck.dev/docs/checks/#SA1006
- SA1006
# Invalid URL in 'net/url.Parse'.
# https://staticcheck.dev/docs/checks/#SA1007
- SA1007
# Non-canonical key in 'http.Header' map.
# https://staticcheck.dev/docs/checks/#SA1008
- SA1008
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
# https://staticcheck.dev/docs/checks/#SA1010
- SA1010
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
# https://staticcheck.dev/docs/checks/#SA1011
- SA1011
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
# https://staticcheck.dev/docs/checks/#SA1012
- SA1012
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
# https://staticcheck.dev/docs/checks/#SA1013
- SA1013
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
# https://staticcheck.dev/docs/checks/#SA1014
- SA1014
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
# https://staticcheck.dev/docs/checks/#SA1015
- SA1015
# Trapping a signal that cannot be trapped.
# https://staticcheck.dev/docs/checks/#SA1016
- SA1016
# Channels used with 'os/signal.Notify' should be buffered.
# https://staticcheck.dev/docs/checks/#SA1017
- SA1017
# 'strings.Replace' called with 'n == 0', which does nothing.
# https://staticcheck.dev/docs/checks/#SA1018
- SA1018
# Using a deprecated function, variable, constant or field.
# https://staticcheck.dev/docs/checks/#SA1019
- SA1019
# Using an invalid host:port pair with a 'net.Listen'-related function.
# https://staticcheck.dev/docs/checks/#SA1020
- SA1020
# Using 'bytes.Equal' to compare two 'net.IP'.
# https://staticcheck.dev/docs/checks/#SA1021
- SA1021
# Modifying the buffer in an 'io.Writer' implementation.
# https://staticcheck.dev/docs/checks/#SA1023
- SA1023
# A string cutset contains duplicate characters.
# https://staticcheck.dev/docs/checks/#SA1024
- SA1024
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
# https://staticcheck.dev/docs/checks/#SA1025
- SA1025
# Cannot marshal channels or functions.
# https://staticcheck.dev/docs/checks/#SA1026
- SA1026
# Atomic access to 64-bit variable must be 64-bit aligned.
# https://staticcheck.dev/docs/checks/#SA1027
- SA1027
# 'sort.Slice' can only be used on slices.
# https://staticcheck.dev/docs/checks/#SA1028
- SA1028
# Inappropriate key in call to 'context.WithValue'.
# https://staticcheck.dev/docs/checks/#SA1029
- SA1029
# Invalid argument in call to a 'strconv' function.
# https://staticcheck.dev/docs/checks/#SA1030
- SA1030
# Overlapping byte slices passed to an encoder.
# https://staticcheck.dev/docs/checks/#SA1031
- SA1031
# Wrong order of arguments to 'errors.Is'.
# https://staticcheck.dev/docs/checks/#SA1032
- SA1032
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
# https://staticcheck.dev/docs/checks/#SA2000
- SA2000
# Empty critical section, did you mean to defer the unlock?.
# https://staticcheck.dev/docs/checks/#SA2001
- SA2001
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
# https://staticcheck.dev/docs/checks/#SA2002
- SA2002
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
# https://staticcheck.dev/docs/checks/#SA2003
- SA2003
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
# https://staticcheck.dev/docs/checks/#SA3000
- SA3000
# Assigning to 'b.N' in benchmarks distorts the results.
# https://staticcheck.dev/docs/checks/#SA3001
- SA3001
# Binary operator has identical expressions on both sides.
# https://staticcheck.dev/docs/checks/#SA4000
- SA4000
# '&*x' gets simplified to 'x', it does not copy 'x'.
# https://staticcheck.dev/docs/checks/#SA4001
- SA4001
# Comparing unsigned values against negative values is pointless.
# https://staticcheck.dev/docs/checks/#SA4003
- SA4003
# The loop exits unconditionally after one iteration.
# https://staticcheck.dev/docs/checks/#SA4004
- SA4004
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
# https://staticcheck.dev/docs/checks/#SA4005
- SA4005
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
# https://staticcheck.dev/docs/checks/#SA4006
- SA4006
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
# https://staticcheck.dev/docs/checks/#SA4008
- SA4008
# A function argument is overwritten before its first use.
# https://staticcheck.dev/docs/checks/#SA4009
- SA4009
# The result of 'append' will never be observed anywhere.
# https://staticcheck.dev/docs/checks/#SA4010
- SA4010
# Break statement with no effect. Did you mean to break out of an outer loop?.
# https://staticcheck.dev/docs/checks/#SA4011
- SA4011
# Comparing a value against NaN even though no value is equal to NaN.
# https://staticcheck.dev/docs/checks/#SA4012
- SA4012
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
# https://staticcheck.dev/docs/checks/#SA4013
- SA4013
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
# https://staticcheck.dev/docs/checks/#SA4014
- SA4014
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
# https://staticcheck.dev/docs/checks/#SA4015
- SA4015
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
# https://staticcheck.dev/docs/checks/#SA4016
- SA4016
# Discarding the return values of a function without side effects, making the call pointless.
# https://staticcheck.dev/docs/checks/#SA4017
- SA4017
# Self-assignment of variables.
# https://staticcheck.dev/docs/checks/#SA4018
- SA4018
# Multiple, identical build constraints in the same file.
# https://staticcheck.dev/docs/checks/#SA4019
- SA4019
# Unreachable case clause in a type switch.
# https://staticcheck.dev/docs/checks/#SA4020
- SA4020
# "x = append(y)" is equivalent to "x = y".
# https://staticcheck.dev/docs/checks/#SA4021
- SA4021
# Comparing the address of a variable against nil.
# https://staticcheck.dev/docs/checks/#SA4022
- SA4022
# Impossible comparison of interface value with untyped nil.
# https://staticcheck.dev/docs/checks/#SA4023
- SA4023
# Checking for impossible return value from a builtin function.
# https://staticcheck.dev/docs/checks/#SA4024
- SA4024
# Integer division of literals that results in zero.
# https://staticcheck.dev/docs/checks/#SA4025
- SA4025
# Go constants cannot express negative zero.
# https://staticcheck.dev/docs/checks/#SA4026
- SA4026
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
# https://staticcheck.dev/docs/checks/#SA4027
- SA4027
# 'x % 1' is always zero.
# https://staticcheck.dev/docs/checks/#SA4028
- SA4028
# Ineffective attempt at sorting slice.
# https://staticcheck.dev/docs/checks/#SA4029
- SA4029
# Ineffective attempt at generating random number.
# https://staticcheck.dev/docs/checks/#SA4030
- SA4030
# Checking never-nil value against nil.
# https://staticcheck.dev/docs/checks/#SA4031
- SA4031
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
# https://staticcheck.dev/docs/checks/#SA4032
- SA4032
# Assignment to nil map.
# https://staticcheck.dev/docs/checks/#SA5000
- SA5000
# Deferring 'Close' before checking for a possible error.
# https://staticcheck.dev/docs/checks/#SA5001
- SA5001
# The empty for loop ("for {}") spins and can block the scheduler.
# https://staticcheck.dev/docs/checks/#SA5002
- SA5002
# Defers in infinite loops will never execute.
# https://staticcheck.dev/docs/checks/#SA5003
- SA5003
# "for { select { ..." with an empty default branch spins.
# https://staticcheck.dev/docs/checks/#SA5004
- SA5004
# The finalizer references the finalized object, preventing garbage collection.
# https://staticcheck.dev/docs/checks/#SA5005
- SA5005
# Infinite recursive call.
# https://staticcheck.dev/docs/checks/#SA5007
- SA5007
# Invalid struct tag.
# https://staticcheck.dev/docs/checks/#SA5008
- SA5008
# Invalid Printf call.
# https://staticcheck.dev/docs/checks/#SA5009
- SA5009
# Impossible type assertion.
# https://staticcheck.dev/docs/checks/#SA5010
- SA5010
# Possible nil pointer dereference.
# https://staticcheck.dev/docs/checks/#SA5011
- SA5011
# Passing odd-sized slice to function expecting even size.
# https://staticcheck.dev/docs/checks/#SA5012
- SA5012
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
# https://staticcheck.dev/docs/checks/#SA6000
- SA6000
# Missing an optimization opportunity when indexing maps by byte slices.
# https://staticcheck.dev/docs/checks/#SA6001
- SA6001
# Storing non-pointer values in 'sync.Pool' allocates memory.
# https://staticcheck.dev/docs/checks/#SA6002
- SA6002
# Converting a string to a slice of runes before ranging over it.
# https://staticcheck.dev/docs/checks/#SA6003
- SA6003
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
# https://staticcheck.dev/docs/checks/#SA6005
- SA6005
# Using io.WriteString to write '[]byte'.
# https://staticcheck.dev/docs/checks/#SA6006
- SA6006
# Defers in range loops may not run when you expect them to.
# https://staticcheck.dev/docs/checks/#SA9001
- SA9001
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
# https://staticcheck.dev/docs/checks/#SA9002
- SA9002
# Empty body in an if or else branch.
# https://staticcheck.dev/docs/checks/#SA9003
- SA9003
# Only the first constant has an explicit type.
# https://staticcheck.dev/docs/checks/#SA9004
- SA9004
# Trying to marshal a struct with no public fields nor custom marshaling.
# https://staticcheck.dev/docs/checks/#SA9005
- SA9005
# Dubious bit shifting of a fixed size integer value.
# https://staticcheck.dev/docs/checks/#SA9006
- SA9006
# Deleting a directory that shouldn't be deleted.
# https://staticcheck.dev/docs/checks/#SA9007
- SA9007
# 'else' branch of a type assertion is probably not reading the right value.
# https://staticcheck.dev/docs/checks/#SA9008
- SA9008
# Ineffectual Go compiler directive.
# https://staticcheck.dev/docs/checks/#SA9009
- SA9009
# Incorrect or missing package comment.
# https://staticcheck.dev/docs/checks/#ST1000
- ST1000
# Dot imports are discouraged.
# https://staticcheck.dev/docs/checks/#ST1001
- ST1001
# Poorly chosen identifier.
# https://staticcheck.dev/docs/checks/#ST1003
- ST1003
# Incorrectly formatted error string.
# https://staticcheck.dev/docs/checks/#ST1005
- ST1005
# Poorly chosen receiver name.
# https://staticcheck.dev/docs/checks/#ST1006
- ST1006
# A function's error value should be its last return value.
# https://staticcheck.dev/docs/checks/#ST1008
- ST1008
# Poorly chosen name for variable of type 'time.Duration'.
# https://staticcheck.dev/docs/checks/#ST1011
- ST1011
# Poorly chosen name for error variable.
# https://staticcheck.dev/docs/checks/#ST1012
- ST1012
# Should use constants for HTTP error codes, not magic numbers.
# https://staticcheck.dev/docs/checks/#ST1013
- ST1013
# A switch's default case should be the first or last case.
# https://staticcheck.dev/docs/checks/#ST1015
- ST1015
# Use consistent method receiver names.
# https://staticcheck.dev/docs/checks/#ST1016
- ST1016
# Don't use Yoda conditions.
# https://staticcheck.dev/docs/checks/#ST1017
- ST1017
# Avoid zero-width and control characters in string literals.
# https://staticcheck.dev/docs/checks/#ST1018
- ST1018
# Importing the same package multiple times.
# https://staticcheck.dev/docs/checks/#ST1019
- ST1019
# The documentation of an exported function should start with the function's name.
# https://staticcheck.dev/docs/checks/#ST1020
- ST1020
# The documentation of an exported type should start with type's name.
# https://staticcheck.dev/docs/checks/#ST1021
- ST1021
# The documentation of an exported variable or constant should start with variable's name.
# https://staticcheck.dev/docs/checks/#ST1022
- ST1022
# Redundant type in variable declaration.
# https://staticcheck.dev/docs/checks/#ST1023
- ST1023
# Use plain channel send or receive instead of single-case select.
# https://staticcheck.dev/docs/checks/#S1000
- S1000
# Replace for loop with call to copy.
# https://staticcheck.dev/docs/checks/#S1001
- S1001
# Omit comparison with boolean constant.
# https://staticcheck.dev/docs/checks/#S1002
- S1002
# Replace call to 'strings.Index' with 'strings.Contains'.
# https://staticcheck.dev/docs/checks/#S1003
- S1003
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
# https://staticcheck.dev/docs/checks/#S1004
- S1004
# Drop unnecessary use of the blank identifier.
# https://staticcheck.dev/docs/checks/#S1005
- S1005
# Use "for { ... }" for infinite loops.
# https://staticcheck.dev/docs/checks/#S1006
- S1006
# Simplify regular expression by using raw string literal.
# https://staticcheck.dev/docs/checks/#S1007
- S1007
# Simplify returning boolean expression.
# https://staticcheck.dev/docs/checks/#S1008
- S1008
# Omit redundant nil check on slices, maps, and channels.
# https://staticcheck.dev/docs/checks/#S1009
- S1009
# Omit default slice index.
# https://staticcheck.dev/docs/checks/#S1010
- S1010
# Use a single 'append' to concatenate two slices.
# https://staticcheck.dev/docs/checks/#S1011
- S1011
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
# https://staticcheck.dev/docs/checks/#S1012
- S1012
# Use a type conversion instead of manually copying struct fields.
# https://staticcheck.dev/docs/checks/#S1016
- S1016
# Replace manual trimming with 'strings.TrimPrefix'.
# https://staticcheck.dev/docs/checks/#S1017
- S1017
# Use "copy" for sliding elements.
# https://staticcheck.dev/docs/checks/#S1018
- S1018
# Simplify "make" call by omitting redundant arguments.
# https://staticcheck.dev/docs/checks/#S1019
- S1019
# Omit redundant nil check in type assertion.
# https://staticcheck.dev/docs/checks/#S1020
- S1020
# Merge variable declaration and assignment.
# https://staticcheck.dev/docs/checks/#S1021
- S1021
# Omit redundant control flow.
# https://staticcheck.dev/docs/checks/#S1023
- S1023
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
# https://staticcheck.dev/docs/checks/#S1024
- S1024
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
# https://staticcheck.dev/docs/checks/#S1025
- S1025
# Simplify error construction with 'fmt.Errorf'.
# https://staticcheck.dev/docs/checks/#S1028
- S1028
# Range over the string directly.
# https://staticcheck.dev/docs/checks/#S1029
- S1029
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
# https://staticcheck.dev/docs/checks/#S1030
- S1030
# Omit redundant nil check around loop.
# https://staticcheck.dev/docs/checks/#S1031
- S1031
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
# https://staticcheck.dev/docs/checks/#S1032
- S1032
# Unnecessary guard around call to "delete".
# https://staticcheck.dev/docs/checks/#S1033
- S1033
# Use result of type assertion to simplify cases.
# https://staticcheck.dev/docs/checks/#S1034
- S1034
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
# https://staticcheck.dev/docs/checks/#S1035
- S1035
# Unnecessary guard around map access.
# https://staticcheck.dev/docs/checks/#S1036
- S1036
# Elaborate way of sleeping.
# https://staticcheck.dev/docs/checks/#S1037
- S1037
# Unnecessarily complex way of printing formatted string.
# https://staticcheck.dev/docs/checks/#S1038
- S1038
# Unnecessary use of 'fmt.Sprint'.
# https://staticcheck.dev/docs/checks/#S1039
- S1039
# Type assertion to current type.
# https://staticcheck.dev/docs/checks/#S1040
- S1040
# Apply De Morgan's law.
# https://staticcheck.dev/docs/checks/#QF1001
- QF1001
# Convert untagged switch to tagged switch.
# https://staticcheck.dev/docs/checks/#QF1002
- QF1002
# Convert if/else-if chain to tagged switch.
# https://staticcheck.dev/docs/checks/#QF1003
- QF1003
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
# https://staticcheck.dev/docs/checks/#QF1004
- QF1004
# Expand call to 'math.Pow'.
# https://staticcheck.dev/docs/checks/#QF1005
- QF1005
# Lift 'if'+'break' into loop condition.
# https://staticcheck.dev/docs/checks/#QF1006
- QF1006
# Merge conditional assignment into variable declaration.
# https://staticcheck.dev/docs/checks/#QF1007
- QF1007
# Omit embedded fields from selector expression.
# https://staticcheck.dev/docs/checks/#QF1008
- QF1008
# Use 'time.Time.Equal' instead of '==' operator.
# https://staticcheck.dev/docs/checks/#QF1009
- QF1009
# Convert slice of bytes to string when printing it.
# https://staticcheck.dev/docs/checks/#QF1010
- QF1010
# Omit redundant type from variable declaration.
# https://staticcheck.dev/docs/checks/#QF1011
- QF1011
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
# https://staticcheck.dev/docs/checks/#QF1012
- QF1012
unused:
# Mark all struct fields that have been written to as used.
# Default: true
field-writes-are-uses: false
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
# Default: false
post-statements-are-reads: true
# Mark all exported fields as used.
# default: true
exported-fields-are-used: false
# Mark all function parameters as used.
# default: true
parameters-are-used: true
# Mark all local variables as used.
# default: true
local-variables-are-used: false
# Mark all identifiers inside generated files as used.
# Default: true
generated-is-used: false
formatters:
enable:
- gofmt
settings:
gofmt:
# Simplify code: gofmt with `-s` option.
# Default: true
simplify: false
# Apply the rewrite rules to the source before reformatting.
# https://pkg.go.dev/cmd/gofmt
# Default: []
rewrite-rules:
- pattern: 'interface{}'
replacement: 'any'
- pattern: 'a[b:len(a)]'
replacement: 'a[b:]'

View File

@@ -1,6 +1,33 @@
.PHONY: wire
.PHONY: wire build build-embed test-unit test-integration test-cover-integration clean-coverage
wire:
@echo "生成 Wire 代码..."
@cd cmd/server && go generate
@echo "Wire 代码生成完成"
@echo "Wire 代码生成完成"
build:
@echo "构建后端(不嵌入前端)..."
@go build -o bin/server ./cmd/server
@echo "构建完成: bin/server"
build-embed:
@echo "构建后端(嵌入前端)..."
@go build -tags embed -o bin/server ./cmd/server
@echo "构建完成: bin/server (with embedded frontend)"
test-unit:
@go test -tags unit ./... -count=1
test-integration:
@go test -tags integration ./... -count=1 -race -parallel=8
test-cover-integration:
@echo "运行集成测试并生成覆盖率报告..."
@go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./...
@go tool cover -func=coverage.out | tail -1
@go tool cover -html=coverage.out -o coverage.html
@echo "覆盖率报告已生成: coverage.html"
clean-coverage:
@rm -f coverage.out coverage.html
@echo "覆盖率文件已清理"

View File

@@ -15,10 +15,11 @@ import (
"syscall"
"time"
"sub2api/internal/handler"
"sub2api/internal/middleware"
"sub2api/internal/setup"
"sub2api/internal/web"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/setup"
"github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin"
)
@@ -83,7 +84,7 @@ func main() {
func runSetupServer() {
r := gin.New()
r.Use(gin.Recovery())
r.Use(middleware.Recovery())
r.Use(middleware.CORS())
// Register setup routes
@@ -94,8 +95,10 @@ func runSetupServer() {
r.Use(web.ServeEmbeddedFrontend())
}
addr := ":8080"
log.Printf("Setup wizard available at http://localhost%s", addr)
// Get server address from config.yaml or environment variables (SERVER_HOST, SERVER_PORT)
// This allows users to run setup on a different address if needed
addr := config.GetServerAddress()
log.Printf("Setup wizard available at http://%s", addr)
log.Println("Complete the setup wizard to configure Sub2API")
if err := r.Run(addr); err != nil {

View File

@@ -4,18 +4,19 @@
package main
import (
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/infrastructure"
"sub2api/internal/repository"
"sub2api/internal/server"
"sub2api/internal/service"
"context"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
@@ -35,11 +36,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// 业务层 ProviderSets
repository.ProviderSet,
service.ProviderSet,
middleware.ProviderSet,
handler.ProviderSet,
// 服务器层 ProviderSet
server.ProviderSet,
// BuildInfo provider
provideServiceBuildInfo,
// 清理函数提供者
provideCleanup,
@@ -49,10 +54,22 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, nil
}
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,
BuildType: buildInfo.BuildType,
}
}
func provideCleanup(
db *gorm.DB,
rdb *redis.Client,
services *service.Services,
tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -63,12 +80,28 @@ func provideCleanup(
name string
fn func() error
}{
{"TokenRefreshService", func() error {
tokenRefresh.Stop()
return nil
}},
{"PricingService", func() error {
services.Pricing.Stop()
pricing.Stop()
return nil
}},
{"EmailQueueService", func() error {
services.EmailQueue.Stop()
emailQueue.Stop()
return nil
}},
{"OAuthService", func() error {
oauth.Stop()
return nil
}},
{"OpenAIOAuthService", func() error {
openaiOAuth.Stop()
return nil
}},
{"GeminiOAuthService", func() error {
geminiOAuth.Stop()
return nil
}},
{"Redis", func() error {

View File

@@ -8,17 +8,18 @@ package main
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
"log"
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/handler/admin"
"sub2api/internal/infrastructure"
"sub2api/internal/repository"
"sub2api/internal/server"
"sub2api/internal/service"
"time"
)
@@ -41,100 +42,95 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
settingRepository := repository.NewSettingRepository(db)
settingService := service.NewSettingService(settingRepository, configConfig)
client := infrastructure.ProvideRedis(configConfig)
emailService := service.NewEmailService(settingRepository, client)
turnstileService := service.NewTurnstileService(settingService)
emailCache := repository.NewEmailCache(client)
emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier()
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
authHandler := handler.NewAuthHandler(authService)
userService := service.NewUserService(userRepository, configConfig)
userService := service.NewUserService(userRepository)
authHandler := handler.NewAuthHandler(authService, userService)
userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, client, configConfig)
apiKeyCache := repository.NewApiKeyCache(client)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(db)
usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db)
repositories := &repository.Repositories{
User: userRepository,
ApiKey: apiKeyRepository,
Group: groupRepository,
Account: accountRepository,
Proxy: proxyRepository,
RedeemCode: redeemCodeRepository,
UsageLog: usageLogRepository,
Setting: settingRepository,
UserSubscription: userSubscriptionRepository,
}
billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository)
subscriptionService := service.NewSubscriptionService(repositories, billingCacheService)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService)
billingCache := repository.NewBillingCache(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(client)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
adminService := service.NewAdminService(repositories, billingCacheService)
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
dashboardService := service.NewDashboardService(usageLogRepository)
dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db)
proxyExitInfoProber := repository.NewProxyExitInfoProber()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
oAuthService := service.NewOAuthService(proxyRepository)
rateLimitService := service.NewRateLimitService(repositories, configConfig)
accountUsageService := service.NewAccountUsageService(repositories, oAuthService)
accountTestService := service.NewAccountTestService(repositories, oAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
geminiTokenCache := repository.NewGeminiTokenCache(client)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, httpUpstream)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService)
systemHandler := handler.ProvideSystemHandler(client, buildInfo)
updateCache := repository.NewUpdateCache(client)
gitHubReleaseClient := repository.NewGitHubReleaseClient()
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
pricingService, err := service.ProvidePricingService(configConfig)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
gatewayCache := repository.NewGatewayCache(client)
pricingRemoteClient := repository.NewPricingRemoteClient()
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(client)
gatewayService := service.NewGatewayService(repositories, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
concurrencyService := service.NewConcurrencyService(client)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
identityCache := repository.NewIdentityCache(client)
identityService := service.NewIdentityService(identityCache)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
groupService := service.NewGroupService(groupRepository)
accountService := service.NewAccountService(accountRepository, groupRepository)
proxyService := service.NewProxyService(proxyRepository)
services := &service.Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OAuth: oAuthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
}
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
httpServer := server.ProvideHTTPServer(configConfig, engine)
v := provideCleanup(db, client, services)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -149,10 +145,22 @@ type Application struct {
Cleanup func()
}
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,
BuildType: buildInfo.BuildType,
}
}
func provideCleanup(
db *gorm.DB,
rdb *redis.Client,
services *service.Services,
tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -162,12 +170,28 @@ func provideCleanup(
name string
fn func() error
}{
{"TokenRefreshService", func() error {
tokenRefresh.Stop()
return nil
}},
{"PricingService", func() error {
services.Pricing.Stop()
pricing.Stop()
return nil
}},
{"EmailQueueService", func() error {
services.EmailQueue.Stop()
emailQueue.Stop()
return nil
}},
{"OAuthService", func() error {
oauth.Stop()
return nil
}},
{"OpenAIOAuthService", func() error {
openaiOAuth.Stop()
return nil
}},
{"GeminiOAuthService", func() error {
geminiOAuth.Stop()
return nil
}},
{"Redis", func() error {

View File

@@ -1,38 +0,0 @@
server:
host: "0.0.0.0"
port: 8080
mode: "debug" # debug/release
database:
host: "127.0.0.1"
port: 5432
user: "postgres"
password: "XZeRr7nkjHWhm8fw"
dbname: "sub2api"
sslmode: "disable"
redis:
host: "127.0.0.1"
port: 6379
password: ""
db: 0
jwt:
secret: "your-secret-key-change-in-production"
expire_hour: 24
default:
admin_email: "admin@sub2api.com"
admin_password: "admin123"
user_concurrency: 5
user_balance: 0
api_key_prefix: "sk-"
rate_multiplier: 1.0
# Timezone configuration (similar to PHP's date_default_timezone_set)
# This affects ALL time operations:
# - Database timestamps
# - Usage statistics "today" boundary
# - Subscription expiry times
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
timezone: "Asia/Shanghai"

View File

@@ -1,4 +1,4 @@
module sub2api
module github.com/Wei-Shaw/sub2api
go 1.24.0
@@ -8,64 +8,121 @@ require (
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.0
github.com/google/uuid v1.6.0
github.com/google/wire v0.7.0
github.com/imroc/req/v3 v3.56.0
github.com/lib/pq v1.10.9
github.com/redis/go-redis/v9 v9.3.0
github.com/redis/go-redis/v9 v9.7.3
github.com/spf13/viper v1.18.2
github.com/stretchr/testify v1.11.1
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
golang.org/x/crypto v0.44.0
golang.org/x/net v0.47.0
golang.org/x/term v0.37.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/datatypes v1.2.0
gorm.io/driver/postgres v1.5.4
gorm.io/gorm v1.25.5
)
require (
dario.cat/mergo v1.0.2 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/google/wire v0.7.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.1 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mdelapenya/tlscert v0.2.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/go-archive v0.1.0 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.6.0 // indirect
github.com/moby/sys/user v0.4.0 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.56.0 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/metric v1.37.0 // indirect
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
@@ -75,6 +132,8 @@ require (
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/tools v0.38.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gorm.io/driver/mysql v1.5.2 // indirect
)

View File

@@ -1,3 +1,13 @@
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
@@ -7,17 +17,43 @@ github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
@@ -28,6 +64,13 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -36,13 +79,19 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
@@ -54,6 +103,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
@@ -62,10 +113,12 @@ github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo=
github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
@@ -85,36 +138,74 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE=
github.com/microsoft/go-mssqldb v0.17.0/go.mod h1:OkoNGhGEs8EZqchVTtochlXruEhEOaO4S0d2sB5aeGQ=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY=
github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c=
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
@@ -128,6 +219,8 @@ github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMV
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -135,16 +228,55 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0/go.mod h1:h+u/2KoREGTnTl9UwrQ/g+XhasAT8E6dClclAADeXoQ=
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0 h1:OG4qwcxp2O0re7V7M9lY9w0v6wWgWf7j7rtkpAnGMd0=
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0/go.mod h1:Bc+EDhKMo5zI5V5zdBkHiMVzeAXbtI4n5isS/nzf6zw=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
@@ -164,8 +296,14 @@ golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
@@ -177,9 +315,15 @@ golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 h1:i8QOKZfYg6AbGVZzUAY3LrNWCKF8O6zFisU9Wl9RER4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4/go.mod h1:HSkG/KdJWusxU1F6CNrwNDjBMgisKxGnc5dAZfT0mjQ=
google.golang.org/grpc v1.75.1 h1:/ODCNEuf9VghjgO3rqLcfg8fiOP0nSluljWFlDxELLI=
google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
@@ -188,10 +332,19 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco=
gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04=
gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs=
gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8=
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0=
gorm.io/driver/sqlserver v1.4.1/go.mod h1:DJ4P+MeZbc5rvY58PnmN1Lnyvb5gw5NPzGshHDnJLig=
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -8,15 +8,41 @@ import (
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
}
type GeminiConfig struct {
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
}
type GeminiOAuthConfig struct {
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
Scopes string `mapstructure:"scopes"`
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
Enabled bool `mapstructure:"enabled"`
// 检查间隔(分钟)
CheckIntervalMinutes int `mapstructure:"check_interval_minutes"`
// 提前刷新时间小时在token过期前多久开始刷新
RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"`
// 最大重试次数
MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
}
type PricingConfig struct {
@@ -37,7 +63,7 @@ type PricingConfig struct {
type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
}
@@ -148,7 +174,7 @@ func setDefaults() {
viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
// Database
viper.SetDefault("database.host", "localhost")
@@ -192,6 +218,20 @@ func setDefaults() {
// Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新适配Google 1小时token
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
// Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
// Default: uses Gemini CLI public credentials (set via environment)
viper.SetDefault("gemini.oauth.client_id", "")
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
}
func (c *Config) Validate() error {
@@ -203,3 +243,29 @@ func (c *Config) Validate() error {
}
return nil
}
// GetServerAddress returns the server address (host:port) from config file or environment variable.
// This is a lightweight function that can be used before full config validation,
// such as during setup wizard startup.
// Priority: config.yaml > environment variables > defaults
func GetServerAddress() string {
v := viper.New()
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath(".")
v.AddConfigPath("./config")
v.AddConfigPath("/etc/sub2api")
// Support SERVER_HOST and SERVER_PORT environment variables
v.AutomaticEnv()
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.SetDefault("server.host", "0.0.0.0")
v.SetDefault("server.port", 8080)
// Try to read config file (ignore errors if not found)
_ = v.ReadInConfig()
host := v.GetString("server.host")
port := v.GetInt("server.port")
return fmt.Sprintf("%s:%d", host, port)
}

View File

@@ -2,9 +2,15 @@ package admin
import (
"strconv"
"strings"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -12,14 +18,12 @@ import (
// OAuthHandler handles OAuth-related operations for accounts
type OAuthHandler struct {
oauthService *service.OAuthService
adminService service.AdminService
}
// NewOAuthHandler creates a new OAuth handler
func NewOAuthHandler(oauthService *service.OAuthService, adminService service.AdminService) *OAuthHandler {
func NewOAuthHandler(oauthService *service.OAuthService) *OAuthHandler {
return &OAuthHandler{
oauthService: oauthService,
adminService: adminService,
}
}
@@ -27,47 +31,84 @@ func NewOAuthHandler(oauthService *service.OAuthService, adminService service.Ad
type AccountHandler struct {
adminService service.AdminService
oauthService *service.OAuthService
openaiOAuthService *service.OpenAIOAuthService
geminiOAuthService *service.GeminiOAuthService
rateLimitService *service.RateLimitService
accountUsageService *service.AccountUsageService
accountTestService *service.AccountTestService
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
}
// NewAccountHandler creates a new admin account handler
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService) *AccountHandler {
func NewAccountHandler(
adminService service.AdminService,
oauthService *service.OAuthService,
openaiOAuthService *service.OpenAIOAuthService,
geminiOAuthService *service.GeminiOAuthService,
rateLimitService *service.RateLimitService,
accountUsageService *service.AccountUsageService,
accountTestService *service.AccountTestService,
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
) *AccountHandler {
return &AccountHandler{
adminService: adminService,
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService,
rateLimitService: rateLimitService,
accountUsageService: accountUsageService,
accountTestService: accountTestService,
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
}
}
// CreateAccountRequest represents create account request
type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]interface{} `json:"credentials" binding:"required"`
Extra map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
Name string `json:"name" binding:"required"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
}
// UpdateAccountRequest represents update account request
// 使用指针类型来区分"未提供"和"设置为0"
type UpdateAccountRequest struct {
Name string `json:"name"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]interface{} `json:"credentials"`
Extra map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
Name string `json:"name"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
}
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
type BulkUpdateAccountsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct {
*dto.Account
CurrentConcurrency int `json:"current_concurrency"`
}
// List handles listing all accounts with pagination
@@ -81,11 +122,32 @@ func (h *AccountHandler) List(c *gin.Context) {
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
if err != nil {
response.InternalError(c, "Failed to list accounts: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, accounts, total, page, pageSize)
// Get current concurrency counts for all accounts
accountIDs := make([]int64, len(accounts))
for i, acc := range accounts {
accountIDs[i] = acc.ID
}
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
if err != nil {
// Log error but don't fail the request, just use 0 for all
concurrencyCounts = make(map[int64]int)
}
// Build response with concurrency info
result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts {
result[i] = AccountWithConcurrency{
Account: dto.AccountFromService(&accounts[i]),
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
}
}
response.Paginated(c, result, total, page, pageSize)
}
// GetByID handles getting an account by ID
@@ -99,11 +161,11 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
response.ErrorFrom(c, err)
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// Create handles creating a new account
@@ -127,11 +189,11 @@ func (h *AccountHandler) Create(c *gin.Context) {
GroupIDs: req.GroupIDs,
})
if err != nil {
response.BadRequest(c, "Failed to create account: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// Update handles updating an account
@@ -161,11 +223,11 @@ func (h *AccountHandler) Update(c *gin.Context) {
GroupIDs: req.GroupIDs,
})
if err != nil {
response.InternalError(c, "Failed to update account: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// Delete handles deleting an account
@@ -179,13 +241,25 @@ func (h *AccountHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteAccount(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to delete account: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Account deleted successfully"})
}
// TestAccountRequest represents the request body for testing an account
type TestAccountRequest struct {
ModelID string `json:"model_id"`
}
type SyncFromCRSRequest struct {
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
SyncProxies *bool `json:"sync_proxies"`
}
// Test handles testing account connectivity with SSE streaming
// POST /api/v1/admin/accounts/:id/test
func (h *AccountHandler) Test(c *gin.Context) {
@@ -195,13 +269,46 @@ func (h *AccountHandler) Test(c *gin.Context) {
return
}
var req TestAccountRequest
// Allow empty body, model_id is optional
_ = c.ShouldBindJSON(&req)
// Use AccountTestService to test the account with SSE streaming
if err := h.accountTestService.TestAccountConnection(c, accountID); err != nil {
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
// Error already sent via SSE, just log
return
}
}
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
// POST /api/v1/admin/accounts/sync/crs
func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
var req SyncFromCRSRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Default to syncing proxies (can be disabled by explicitly setting false)
syncProxies := true
if req.SyncProxies != nil {
syncProxies = *req.SyncProxies
}
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
SyncProxies: syncProxies,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) {
@@ -224,32 +331,74 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
return
}
// Use OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
var newCredentials map[string]any
// Update account credentials
newCredentials := map[string]interface{}{
"access_token": tokenInfo.AccessToken,
"token_type": tokenInfo.TokenType,
"expires_in": tokenInfo.ExpiresIn,
"expires_at": tokenInfo.ExpiresAt,
"refresh_token": tokenInfo.RefreshToken,
"scope": tokenInfo.Scope,
if account.IsOpenAI() {
// Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Build new credentials from token info
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Preserve non-token settings from existing credentials
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
} else if account.Platform == service.PlatformGemini {
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
} else {
// Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
newCredentials = make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
// Update token-related fields
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if strings.TrimSpace(tokenInfo.RefreshToken) != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if strings.TrimSpace(tokenInfo.Scope) != "" {
newCredentials["scope"] = tokenInfo.Scope
}
}
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, updatedAccount)
response.Success(c, dto.AccountFromService(updatedAccount))
}
// GetStats handles getting account statistics
@@ -261,15 +410,26 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
return
}
// Return mock data for now
_ = accountID
response.Success(c, gin.H{
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"total_tokens": 0,
"average_response_time": 0,
})
// Parse days parameter (default 30)
days := 30
if daysStr := c.Query("days"); daysStr != "" {
if d, err := strconv.Atoi(daysStr); err == nil && d > 0 && d <= 90 {
days = d
}
}
// Calculate time range
now := timezone.Now()
endTime := timezone.StartOfDay(now.AddDate(0, 0, 1))
startTime := timezone.StartOfDay(now.AddDate(0, 0, -days+1))
stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, stats)
}
// ClearError handles clearing account error
@@ -283,11 +443,11 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to clear error: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// BatchCreate handles batch creating accounts
@@ -309,6 +469,136 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
})
}
// BatchUpdateCredentialsRequest represents batch credentials update request
type BatchUpdateCredentialsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Field string `json:"field" binding:"required,oneof=account_uuid org_uuid intercept_warmup_requests"`
Value any `json:"value"`
}
// BatchUpdateCredentials handles batch updating credentials fields
// POST /api/v1/admin/accounts/batch-update-credentials
func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
var req BatchUpdateCredentialsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Validate value type based on field
if req.Field == "intercept_warmup_requests" {
// Must be boolean
if _, ok := req.Value.(bool); !ok {
response.BadRequest(c, "intercept_warmup_requests must be boolean")
return
}
} else {
// account_uuid and org_uuid can be string or null
if req.Value != nil {
if _, ok := req.Value.(string); !ok {
response.BadRequest(c, req.Field+" must be string or null")
return
}
}
}
ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
}
// Update credentials field
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials[req.Field] = req.Value
// Update account
updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials,
}
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"account_id": accountID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
})
}
// BulkUpdate handles bulk updating accounts with selected fields/credentials.
// POST /api/v1/admin/accounts/bulk-update
func (h *AccountHandler) BulkUpdate(c *gin.Context) {
var req BulkUpdateAccountsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
hasUpdates := req.Name != "" ||
req.ProxyID != nil ||
req.Concurrency != nil ||
req.Priority != nil ||
req.Status != "" ||
req.GroupIDs != nil ||
len(req.Credentials) > 0 ||
len(req.Extra) > 0
if !hasUpdates {
response.BadRequest(c, "No updates provided")
return
}
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
AccountIDs: req.AccountIDs,
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: req.Status,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ========== OAuth Handlers ==========
// GenerateAuthURLRequest represents the request for generating auth URL
@@ -327,7 +617,7 @@ func (h *OAuthHandler) GenerateAuthURL(c *gin.Context) {
result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -345,7 +635,7 @@ func (h *OAuthHandler) GenerateSetupTokenURL(c *gin.Context) {
result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "Failed to generate setup token URL: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -374,7 +664,7 @@ func (h *OAuthHandler) ExchangeCode(c *gin.Context) {
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -396,7 +686,7 @@ func (h *OAuthHandler) ExchangeSetupTokenCode(c *gin.Context) {
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -424,7 +714,7 @@ func (h *OAuthHandler) CookieAuth(c *gin.Context) {
Scope: "full",
})
if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -446,7 +736,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
Scope: "inference",
})
if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -464,7 +754,7 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to get usage: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -482,7 +772,7 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to clear rate limit: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -500,7 +790,7 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to get today stats: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -529,9 +819,142 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable)
if err != nil {
response.InternalError(c, "Failed to update schedulable status: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// GetAvailableModels handles getting available models for an account
// GET /api/v1/admin/accounts/:id/models
func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
// Handle OpenAI accounts
if account.IsOpenAI() {
// For OAuth accounts: return default OpenAI models
if account.IsOAuth() {
response.Success(c, openai.DefaultModels)
return
}
// For API Key accounts: check model_mapping
mapping := account.GetModelMapping()
if len(mapping) == 0 {
response.Success(c, openai.DefaultModels)
return
}
// Return mapped models
var models []openai.Model
for requestedModel := range mapping {
var found bool
for _, dm := range openai.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
if !found {
models = append(models, openai.Model{
ID: requestedModel,
Object: "model",
Type: "model",
DisplayName: requestedModel,
})
}
}
response.Success(c, models)
return
}
// Handle Gemini accounts
if account.IsGemini() {
// For OAuth accounts: return default Gemini models
if account.IsOAuth() {
response.Success(c, geminicli.DefaultModels)
return
}
// For API Key accounts: return models based on model_mapping
mapping := account.GetModelMapping()
if len(mapping) == 0 {
response.Success(c, geminicli.DefaultModels)
return
}
var models []geminicli.Model
for requestedModel := range mapping {
var found bool
for _, dm := range geminicli.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
if !found {
models = append(models, geminicli.Model{
ID: requestedModel,
Type: "model",
DisplayName: requestedModel,
CreatedAt: "",
})
}
}
response.Success(c, models)
return
}
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
response.Success(c, claude.DefaultModels)
return
}
// For API Key accounts: return models based on model_mapping
mapping := account.GetModelMapping()
if len(mapping) == 0 {
// No mapping configured, return default models
response.Success(c, claude.DefaultModels)
return
}
// Return mapped models (keys of the mapping are the available model IDs)
var models []claude.Model
for requestedModel := range mapping {
// Try to find display info from default models
var found bool
for _, dm := range claude.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
// If not found in defaults, create a basic entry
if !found {
models = append(models, claude.Model{
ID: requestedModel,
Type: "model",
DisplayName: requestedModel,
CreatedAt: "",
})
}
}
response.Success(c, models)
}

View File

@@ -2,28 +2,26 @@ package admin
import (
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// DashboardHandler handles admin dashboard statistics
type DashboardHandler struct {
adminService service.AdminService
usageRepo *repository.UsageLogRepository
startTime time.Time // Server start time for uptime calculation
dashboardService *service.DashboardService
startTime time.Time // Server start time for uptime calculation
}
// NewDashboardHandler creates a new admin dashboard handler
func NewDashboardHandler(adminService service.AdminService, usageRepo *repository.UsageLogRepository) *DashboardHandler {
func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
return &DashboardHandler{
adminService: adminService,
usageRepo: usageRepo,
startTime: time.Now(),
dashboardService: dashboardService,
startTime: time.Now(),
}
}
@@ -61,7 +59,7 @@ func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
// GetStats handles getting dashboard statistics
// GET /api/v1/admin/dashboard/stats
func (h *DashboardHandler) GetStats(c *gin.Context) {
stats, err := h.usageRepo.GetDashboardStats(c.Request.Context())
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics")
return
@@ -110,6 +108,10 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
// 系统运行统计
"average_duration_ms": stats.AverageDurationMs,
"uptime": uptime,
// 性能指标
"rpm": stats.Rpm,
"tpm": stats.Tpm,
})
}
@@ -127,12 +129,25 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour)
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageRepo.GetUsageTrend(c.Request.Context(), startTime, endTime, granularity)
// Parse optional filter params
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
}
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
apiKeyID = id
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
@@ -148,11 +163,24 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD)
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
stats, err := h.usageRepo.GetModelStats(c.Request.Context(), startTime, endTime)
// Parse optional filter params
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
}
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
apiKeyID = id
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
@@ -177,7 +205,7 @@ func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
limit = 5
}
trend, err := h.usageRepo.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get API key usage trend")
return
@@ -203,7 +231,7 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
limit = 12
}
trend, err := h.usageRepo.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
@@ -232,11 +260,11 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
}
if len(req.UserIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.usageRepo.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
@@ -260,11 +288,11 @@ func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
}
if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return

View File

@@ -0,0 +1,135 @@
package admin
import (
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type GeminiOAuthHandler struct {
geminiOAuthService *service.GeminiOAuthService
}
func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *GeminiOAuthHandler {
return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService}
}
// GET /api/v1/admin/gemini/oauth/capabilities
func (h *GeminiOAuthHandler) GetCapabilities(c *gin.Context) {
cfg := h.geminiOAuthService.GetOAuthConfig()
response.Success(c, cfg)
}
type GeminiGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
ProjectID string `json:"project_id"`
// OAuth 类型: "code_assist" (需要 project_id) 或 "ai_studio" (不需要 project_id)
// 默认为 "code_assist" 以保持向后兼容
OAuthType string `json:"oauth_type"`
}
// GenerateAuthURL generates Google OAuth authorization URL for Gemini.
// POST /api/v1/admin/gemini/oauth/auth-url
func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
var req GeminiGenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 默认使用 code_assist 以保持向后兼容
oauthType := strings.TrimSpace(req.OAuthType)
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
return
}
// Always pass the "hosted" callback URI; the OAuth service may override it depending on
// oauth_type and whether the built-in Gemini CLI OAuth client is used.
redirectURI := deriveGeminiRedirectURI(c)
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType)
if err != nil {
msg := err.Error()
// Treat missing/invalid OAuth client configuration as a user/config error.
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
response.BadRequest(c, "Failed to generate auth URL: "+msg)
return
}
response.InternalError(c, "Failed to generate auth URL: "+msg)
return
}
response.Success(c, result)
}
type GeminiExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
State string `json:"state" binding:"required"`
Code string `json:"code" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
// OAuth 类型: "code_assist" 或 "ai_studio",需要与 GenerateAuthURL 时的类型一致
OAuthType string `json:"oauth_type"`
}
// ExchangeCode exchanges authorization code for tokens.
// POST /api/v1/admin/gemini/oauth/exchange-code
func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
var req GeminiExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 默认使用 code_assist 以保持向后兼容
oauthType := strings.TrimSpace(req.OAuthType)
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
return
}
tokenInfo, err := h.geminiOAuthService.ExchangeCode(c.Request.Context(), &service.GeminiExchangeCodeInput{
SessionID: req.SessionID,
State: req.State,
Code: req.Code,
ProxyID: req.ProxyID,
OAuthType: oauthType,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
func deriveGeminiRedirectURI(c *gin.Context) string {
origin := strings.TrimSpace(c.GetHeader("Origin"))
if origin != "" {
return strings.TrimRight(origin, "/") + "/auth/callback"
}
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
if xfProto := strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")); xfProto != "" {
scheme = strings.TrimSpace(strings.Split(xfProto, ",")[0])
}
host := strings.TrimSpace(c.Request.Host)
if xfHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); xfHost != "" {
host = strings.TrimSpace(strings.Split(xfHost, ",")[0])
}
return fmt.Sprintf("%s://%s/auth/callback", scheme, host)
}

View File

@@ -3,9 +3,9 @@ package admin
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -65,11 +65,15 @@ func (h *GroupHandler) List(c *gin.Context) {
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
if err != nil {
response.InternalError(c, "Failed to list groups: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, groups, total, page, pageSize)
outGroups := make([]dto.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
}
response.Paginated(c, outGroups, total, page, pageSize)
}
// GetAll handles getting all active groups without pagination
@@ -77,7 +81,7 @@ func (h *GroupHandler) List(c *gin.Context) {
func (h *GroupHandler) GetAll(c *gin.Context) {
platform := c.Query("platform")
var groups []model.Group
var groups []service.Group
var err error
if platform != "" {
@@ -87,11 +91,15 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
}
if err != nil {
response.InternalError(c, "Failed to get groups: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, groups)
outGroups := make([]dto.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
}
response.Success(c, outGroups)
}
// GetByID handles getting a group by ID
@@ -105,11 +113,11 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
if err != nil {
response.NotFound(c, "Group not found")
response.ErrorFrom(c, err)
return
}
response.Success(c, group)
response.Success(c, dto.GroupFromService(group))
}
// Create handles creating a new group
@@ -133,11 +141,11 @@ func (h *GroupHandler) Create(c *gin.Context) {
MonthlyLimitUSD: req.MonthlyLimitUSD,
})
if err != nil {
response.BadRequest(c, "Failed to create group: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, group)
response.Success(c, dto.GroupFromService(group))
}
// Update handles updating a group
@@ -168,11 +176,11 @@ func (h *GroupHandler) Update(c *gin.Context) {
MonthlyLimitUSD: req.MonthlyLimitUSD,
})
if err != nil {
response.InternalError(c, "Failed to update group: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, group)
response.Success(c, dto.GroupFromService(group))
}
// Delete handles deleting a group
@@ -186,7 +194,7 @@ func (h *GroupHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
if err != nil {
response.InternalError(c, "Failed to delete group: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -225,9 +233,13 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get group API keys: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, keys, total, page, pageSize)
outKeys := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, outKeys, total, page, pageSize)
}

View File

@@ -0,0 +1,229 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// OpenAIOAuthHandler handles OpenAI OAuth-related operations
type OpenAIOAuthHandler struct {
openaiOAuthService *service.OpenAIOAuthService
adminService service.AdminService
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{
openaiOAuthService: openaiOAuthService,
adminService: adminService,
}
}
// OpenAIGenerateAuthURLRequest represents the request for generating OpenAI auth URL
type OpenAIGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
RedirectURI string `json:"redirect_uri"`
}
// GenerateAuthURL generates OpenAI OAuth authorization URL
// POST /api/v1/admin/openai/generate-auth-url
func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
var req OpenAIGenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body
req = OpenAIGenerateAuthURLRequest{}
}
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// OpenAIExchangeCodeRequest represents the request for exchanging OpenAI auth code
type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
}
// ExchangeCode exchanges OpenAI authorization code for tokens
// POST /api/v1/admin/openai/exchange-code
func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
var req OpenAIExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
var proxyURL string
if req.ProxyID != nil {
proxy, err := h.adminService.GetProxy(c.Request.Context(), *req.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI account
// POST /api/v1/admin/openai/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Ensure account is OpenAI platform
if !account.IsOpenAI() {
response.BadRequest(c, "Account is not an OpenAI account")
return
}
// Only refresh OAuth-based accounts
if !account.IsOAuth() {
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
return
}
// Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Build new credentials from token info
newCredentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Preserve non-token settings from existing credentials
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AccountFromService(updatedAccount))
}
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Exchange code for tokens
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
// Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Use email as default name if not provided
name := req.Name
if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email
}
if name == "" {
name = "OpenAI OAuth Account"
}
// Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name,
Platform: "openai",
Type: "oauth",
Credentials: credentials,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
GroupIDs: req.GroupIDs,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AccountFromService(account))
}

View File

@@ -2,9 +2,11 @@ package admin
import (
"strconv"
"strings"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -52,11 +54,15 @@ func (h *ProxyHandler) List(c *gin.Context) {
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
if err != nil {
response.InternalError(c, "Failed to list proxies: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, proxies, total, page, pageSize)
out := make([]dto.Proxy, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// GetAll handles getting all active proxies without pagination
@@ -68,20 +74,28 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
if withCount {
proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, proxies)
out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
}
response.Success(c, out)
return
}
proxies, err := h.adminService.GetAllProxies(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, proxies)
out := make([]dto.Proxy, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i]))
}
response.Success(c, out)
}
// GetByID handles getting a proxy by ID
@@ -95,11 +109,11 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID)
if err != nil {
response.NotFound(c, "Proxy not found")
response.ErrorFrom(c, err)
return
}
response.Success(c, proxy)
response.Success(c, dto.ProxyFromService(proxy))
}
// Create handles creating a new proxy
@@ -112,19 +126,19 @@ func (h *ProxyHandler) Create(c *gin.Context) {
}
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: req.Username,
Password: req.Password,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
})
if err != nil {
response.BadRequest(c, "Failed to create proxy: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, proxy)
response.Success(c, dto.ProxyFromService(proxy))
}
// Update handles updating a proxy
@@ -143,20 +157,20 @@ func (h *ProxyHandler) Update(c *gin.Context) {
}
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: req.Username,
Password: req.Password,
Status: req.Status,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
Status: strings.TrimSpace(req.Status),
})
if err != nil {
response.InternalError(c, "Failed to update proxy: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, proxy)
response.Success(c, dto.ProxyFromService(proxy))
}
// Delete handles deleting a proxy
@@ -170,7 +184,7 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteProxy(c.Request.Context(), proxyID)
if err != nil {
response.InternalError(c, "Failed to delete proxy: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -188,7 +202,7 @@ func (h *ProxyHandler) Test(c *gin.Context) {
result, err := h.adminService.TestProxy(c.Request.Context(), proxyID)
if err != nil {
response.InternalError(c, "Failed to test proxy: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -228,14 +242,17 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get proxy accounts: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, accounts, total, page, pageSize)
out := make([]dto.Account, 0, len(accounts))
for i := range accounts {
out = append(out, *dto.AccountFromService(&accounts[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// BatchCreateProxyItem represents a single proxy in batch create request
type BatchCreateProxyItem struct {
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
@@ -263,10 +280,16 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
skipped := 0
for _, item := range req.Proxies {
// Trim all string fields
host := strings.TrimSpace(item.Host)
protocol := strings.TrimSpace(item.Protocol)
username := strings.TrimSpace(item.Username)
password := strings.TrimSpace(item.Password)
// Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), item.Host, item.Port, item.Username, item.Password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
if err != nil {
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -278,11 +301,11 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
// Create proxy with default name
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: "default",
Protocol: item.Protocol,
Host: item.Host,
Protocol: protocol,
Host: host,
Port: item.Port,
Username: item.Username,
Password: item.Password,
Username: username,
Password: password,
})
if err != nil {
// If creation fails due to duplicate, count as skipped

View File

@@ -6,8 +6,9 @@ import (
"fmt"
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -43,11 +44,15 @@ func (h *RedeemHandler) List(c *gin.Context) {
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
if err != nil {
response.InternalError(c, "Failed to list redeem codes: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, codes, total, page, pageSize)
out := make([]dto.RedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// GetByID handles getting a redeem code by ID
@@ -61,11 +66,11 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.NotFound(c, "Redeem code not found")
response.ErrorFrom(c, err)
return
}
response.Success(c, code)
response.Success(c, dto.RedeemCodeFromService(code))
}
// Generate handles generating new redeem codes
@@ -85,11 +90,15 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
ValidityDays: req.ValidityDays,
})
if err != nil {
response.InternalError(c, "Failed to generate redeem codes: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, codes)
out := make([]dto.RedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
}
response.Success(c, out)
}
// Delete handles deleting a redeem code
@@ -103,7 +112,7 @@ func (h *RedeemHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.InternalError(c, "Failed to delete redeem code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -123,7 +132,7 @@ func (h *RedeemHandler) BatchDelete(c *gin.Context) {
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
if err != nil {
response.InternalError(c, "Failed to batch delete redeem codes: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -144,11 +153,11 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.InternalError(c, "Failed to expire redeem code: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, code)
response.Success(c, dto.RedeemCodeFromService(code))
}
// GetStats handles getting redeem code statistics
@@ -156,10 +165,10 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
func (h *RedeemHandler) GetStats(c *gin.Context) {
// Return mock data for now
response.Success(c, gin.H{
"total_codes": 0,
"active_codes": 0,
"used_codes": 0,
"expired_codes": 0,
"total_codes": 0,
"active_codes": 0,
"used_codes": 0,
"expired_codes": 0,
"total_value_distributed": 0.0,
"by_type": gin.H{
"balance": 0,
@@ -178,7 +187,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
// Get all codes without pagination (use large page size)
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
if err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -187,7 +196,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
writer := csv.NewWriter(&buf)
// Write header
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"})
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
// Write data rows
for _, code := range codes {
@@ -199,7 +211,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
if code.UsedAt != nil {
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
}
writer.Write([]string{
if err := writer.Write([]string{
fmt.Sprintf("%d", code.ID),
code.Code,
code.Type,
@@ -208,10 +220,17 @@ func (h *RedeemHandler) Export(c *gin.Context) {
usedBy,
usedAt,
code.CreatedAt.Format("2006-01-02 15:04:05"),
})
}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
}
writer.Flush()
if err := writer.Error(); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
c.Header("Content-Type", "text/csv")
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")

View File

@@ -1,9 +1,9 @@
package admin
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -27,11 +27,32 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
func (h *SettingHandler) GetSettings(c *gin.Context) {
settings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, settings)
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SmtpHost: settings.SmtpHost,
SmtpPort: settings.SmtpPort,
SmtpUsername: settings.SmtpUsername,
SmtpPassword: settings.SmtpPassword,
SmtpFrom: settings.SmtpFrom,
SmtpFromName: settings.SmtpFromName,
SmtpUseTLS: settings.SmtpUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
})
}
// UpdateSettingsRequest 更新设置请求
@@ -60,6 +81,7 @@ type UpdateSettingsRequest struct {
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
@@ -86,7 +108,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SmtpPort = 587
}
settings := &model.SystemSettings{
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
SmtpHost: req.SmtpHost,
@@ -104,23 +126,45 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
SiteSubtitle: req.SiteSubtitle,
ApiBaseUrl: req.ApiBaseUrl,
ContactInfo: req.ContactInfo,
DocUrl: req.DocUrl,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
response.InternalError(c, "Failed to update settings: "+err.Error())
response.ErrorFrom(c, err)
return
}
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get updated settings: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, updatedSettings)
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SmtpHost: updatedSettings.SmtpHost,
SmtpPort: updatedSettings.SmtpPort,
SmtpUsername: updatedSettings.SmtpUsername,
SmtpPassword: updatedSettings.SmtpPassword,
SmtpFrom: updatedSettings.SmtpFrom,
SmtpFromName: updatedSettings.SmtpFromName,
SmtpUseTLS: updatedSettings.SmtpUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
ApiBaseUrl: updatedSettings.ApiBaseUrl,
ContactInfo: updatedSettings.ContactInfo,
DocUrl: updatedSettings.DocUrl,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
})
}
// TestSmtpRequest 测试SMTP连接请求
@@ -164,7 +208,7 @@ func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
err := h.emailService.TestSmtpConnectionWithConfig(config)
if err != nil {
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -250,9 +294,49 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
`
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
response.BadRequest(c, "Failed to send test email: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Test email sent successfully"})
}
// GetAdminApiKey 获取管理员 API Key 状态
// GET /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"exists": exists,
"masked_key": maskedKey,
})
}
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
// POST /api/v1/admin/settings/admin-api-key/regenerate
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"key": key, // 完整 key 只在生成时返回一次
})
}
// DeleteAdminApiKey 删除管理员 API Key
// DELETE /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Admin API key deleted"})
}

View File

@@ -3,16 +3,17 @@ package admin
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// toResponsePagination converts repository.PaginationResult to response.PaginationResult
func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult {
// toResponsePagination converts pagination.PaginationResult to response.PaginationResult
func toResponsePagination(p *pagination.PaginationResult) *response.PaginationResult {
if p == nil {
return nil
}
@@ -78,11 +79,15 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
// GetByID handles getting a subscription by ID
@@ -96,11 +101,11 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
if err != nil {
response.NotFound(c, "Subscription not found")
response.ErrorFrom(c, err)
return
}
response.Success(c, subscription)
response.Success(c, dto.UserSubscriptionFromService(subscription))
}
// GetProgress handles getting subscription usage progress
@@ -141,11 +146,11 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
Notes: req.Notes,
})
if err != nil {
response.BadRequest(c, "Failed to assign subscription: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, subscription)
response.Success(c, dto.UserSubscriptionFromService(subscription))
}
// BulkAssign handles bulk assigning subscriptions to multiple users
@@ -168,11 +173,11 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
Notes: req.Notes,
})
if err != nil {
response.InternalError(c, "Failed to bulk assign subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
response.Success(c, dto.BulkAssignResultFromService(result))
}
// Extend handles extending a subscription
@@ -192,11 +197,11 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
if err != nil {
response.InternalError(c, "Failed to extend subscription: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, subscription)
response.Success(c, dto.UserSubscriptionFromService(subscription))
}
// Revoke handles revoking a subscription
@@ -210,7 +215,7 @@ func (h *SubscriptionHandler) Revoke(c *gin.Context) {
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
if err != nil {
response.InternalError(c, "Failed to revoke subscription: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -230,11 +235,15 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to list group subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
// ListByUser handles listing subscriptions for a specific user
@@ -248,19 +257,22 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
if err != nil {
response.InternalError(c, "Failed to list user subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, subscriptions)
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
}
// Helper function to get admin ID from context
func getAdminIDFromContext(c *gin.Context) int64 {
if user, exists := c.Get("user"); exists {
if u, ok := user.(*model.User); ok && u != nil {
return u.ID
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
return 0
}
return 0
return subject.UserID
}

View File

@@ -4,12 +4,11 @@ import (
"net/http"
"time"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/sysutil"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// SystemHandler handles system-related operations
@@ -18,9 +17,9 @@ type SystemHandler struct {
}
// NewSystemHandler creates a new SystemHandler
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
return &SystemHandler{
updateSvc: service.NewUpdateService(rdb, version, buildType),
updateSvc: updateSvc,
}
}

View File

@@ -4,34 +4,33 @@ import (
"strconv"
"time"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageRepo *repository.UsageLogRepository
apiKeyRepo *repository.ApiKeyRepository
usageService *service.UsageService
apiKeyService *service.ApiKeyService
adminService service.AdminService
}
// NewUsageHandler creates a new admin usage handler
func NewUsageHandler(
usageRepo *repository.UsageLogRepository,
apiKeyRepo *repository.ApiKeyRepository,
usageService *service.UsageService,
apiKeyService *service.ApiKeyService,
adminService service.AdminService,
) *UsageHandler {
return &UsageHandler{
usageRepo: usageRepo,
apiKeyRepo: apiKeyRepo,
usageService: usageService,
adminService: adminService,
usageService: usageService,
apiKeyService: apiKeyService,
adminService: adminService,
}
}
@@ -82,21 +81,25 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t
}
params := repository.PaginationParams{Page: page, PageSize: pageSize}
filters := repository.UsageLogFilters{
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := usagestats.UsageLogFilters{
UserID: userID,
ApiKeyID: apiKeyID,
StartTime: startTime,
EndTime: endTime,
}
records, result, err := h.usageRepo.ListWithFilters(c.Request.Context(), params, filters)
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, records, result.Total, page, pageSize)
out := make([]dto.UsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromService(&records[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}
// Stats handles getting usage statistics with filters
@@ -160,7 +163,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if apiKeyID > 0 {
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, stats)
@@ -170,7 +173,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if userID > 0 {
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, stats)
@@ -178,9 +181,9 @@ func (h *UsageHandler) Stats(c *gin.Context) {
}
// Get global stats
stats, err := h.usageRepo.GetGlobalStats(c.Request.Context(), startTime, endTime)
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -192,14 +195,14 @@ func (h *UsageHandler) Stats(c *gin.Context) {
func (h *UsageHandler) SearchUsers(c *gin.Context) {
keyword := c.Query("q")
if keyword == "" {
response.Success(c, []interface{}{})
response.Success(c, []any{})
return
}
// Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
if err != nil {
response.InternalError(c, "Failed to search users: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -236,9 +239,9 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
userID = id
}
keys, err := h.apiKeyRepo.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
if err != nil {
response.InternalError(c, "Failed to search API keys: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -3,8 +3,9 @@ package admin
import (
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -25,6 +26,9 @@ func NewUserHandler(adminService service.AdminService) *UserHandler {
type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
AllowedGroups []int64 `json:"allowed_groups"`
@@ -35,6 +39,9 @@ type CreateUserRequest struct {
type UpdateUserRequest struct {
Email string `json:"email" binding:"omitempty,email"`
Password string `json:"password" binding:"omitempty,min=6"`
Username *string `json:"username"`
Wechat *string `json:"wechat"`
Notes *string `json:"notes"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
@@ -43,8 +50,9 @@ type UpdateUserRequest struct {
// UpdateBalanceRequest represents balance update request
type UpdateBalanceRequest struct {
Balance float64 `json:"balance" binding:"required"`
Balance float64 `json:"balance" binding:"required,gt=0"`
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
Notes string `json:"notes"`
}
// List handles listing all users with pagination
@@ -57,11 +65,15 @@ func (h *UserHandler) List(c *gin.Context) {
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
if err != nil {
response.InternalError(c, "Failed to list users: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, users, total, page, pageSize)
out := make([]dto.User, 0, len(users))
for i := range users {
out = append(out, *dto.UserFromService(&users[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// GetByID handles getting a user by ID
@@ -75,11 +87,11 @@ func (h *UserHandler) GetByID(c *gin.Context) {
user, err := h.adminService.GetUser(c.Request.Context(), userID)
if err != nil {
response.NotFound(c, "User not found")
response.ErrorFrom(c, err)
return
}
response.Success(c, user)
response.Success(c, dto.UserFromService(user))
}
// Create handles creating a new user
@@ -94,16 +106,19 @@ func (h *UserHandler) Create(c *gin.Context) {
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
Email: req.Email,
Password: req.Password,
Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.BadRequest(c, "Failed to create user: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, user)
response.Success(c, dto.UserFromService(user))
}
// Update handles updating a user
@@ -125,17 +140,20 @@ func (h *UserHandler) Update(c *gin.Context) {
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
Email: req.Email,
Password: req.Password,
Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.InternalError(c, "Failed to update user: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, user)
response.Success(c, dto.UserFromService(user))
}
// Delete handles deleting a user
@@ -149,7 +167,7 @@ func (h *UserHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteUser(c.Request.Context(), userID)
if err != nil {
response.InternalError(c, "Failed to delete user: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -171,13 +189,13 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return
}
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation)
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
if err != nil {
response.InternalError(c, "Failed to update balance: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, user)
response.Success(c, dto.UserFromService(user))
}
// GetUserAPIKeys handles getting user's API keys
@@ -193,11 +211,15 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get user API keys: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, keys, total, page, pageSize)
out := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// GetUserUsage handles getting user's usage statistics
@@ -213,7 +235,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
if err != nil {
response.InternalError(c, "Failed to get user usage: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -3,10 +3,11 @@ package handler
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -40,42 +41,34 @@ type UpdateAPIKeyRequest struct {
// List handles listing user's API keys with pagination
// GET /api/v1/api-keys
func (h *APIKeyHandler) List(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
page, pageSize := response.ParsePagination(c)
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
if err != nil {
response.InternalError(c, "Failed to list API keys: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, keys, result.Total, page, pageSize)
out := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}
// GetByID handles getting a single API key
// GET /api/v1/api-keys/:id
func (h *APIKeyHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -87,31 +80,25 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
if err != nil {
response.NotFound(c, "API key not found")
response.ErrorFrom(c, err)
return
}
// 验证所有权
if key.UserID != user.ID {
if key.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this key")
return
}
response.Success(c, key)
response.Success(c, dto.ApiKeyFromService(key))
}
// Create handles creating a new API key
// POST /api/v1/api-keys
func (h *APIKeyHandler) Create(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -126,27 +113,21 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
GroupID: req.GroupID,
CustomKey: req.CustomKey,
}
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.InternalError(c, "Failed to create API key: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, key)
response.Success(c, dto.ApiKeyFromService(key))
}
// Update handles updating an API key
// PUT /api/v1/api-keys/:id
func (h *APIKeyHandler) Update(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -171,27 +152,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
svcReq.Status = &req.Status
}
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq)
if err != nil {
response.InternalError(c, "Failed to update API key: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, key)
response.Success(c, dto.ApiKeyFromService(key))
}
// Delete handles deleting an API key
// DELETE /api/v1/api-keys/:id
func (h *APIKeyHandler) Delete(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -201,9 +176,9 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
return
}
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
err = h.apiKeyService.Delete(c.Request.Context(), keyID, subject.UserID)
if err != nil {
response.InternalError(c, "Failed to delete API key: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -213,23 +188,21 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
// GetAvailableGroups 获取用户可以绑定的分组列表
// GET /api/v1/groups/available
func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to get available groups: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, groups)
out := make([]dto.Group, 0, len(groups))
for i := range groups {
out = append(out, *dto.GroupFromService(&groups[i]))
}
response.Success(c, out)
}

View File

@@ -1,9 +1,10 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -11,12 +12,14 @@ import (
// AuthHandler handles authentication-related requests
type AuthHandler struct {
authService *service.AuthService
userService *service.UserService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
func NewAuthHandler(authService *service.AuthService, userService *service.UserService) *AuthHandler {
return &AuthHandler{
authService: authService,
userService: userService,
}
}
@@ -49,9 +52,9 @@ type LoginRequest struct {
// AuthResponse 认证响应格式(匹配前端期望)
type AuthResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
User *model.User `json:"user"`
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
User *dto.User `json:"user"`
}
// Register handles user registration
@@ -66,21 +69,21 @@ func (h *AuthHandler) Register(c *gin.Context) {
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if req.VerifyCode == "" {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
}
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
if err != nil {
response.BadRequest(c, "Registration failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: user,
User: dto.UserFromService(user),
})
}
@@ -95,13 +98,13 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
if err != nil {
response.BadRequest(c, "Failed to send verification code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -122,37 +125,37 @@ func (h *AuthHandler) Login(c *gin.Context) {
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
if err != nil {
response.Unauthorized(c, "Login failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: user,
User: dto.UserFromService(user),
})
}
// GetCurrentUser handles getting current authenticated user
// GET /api/v1/auth/me
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
user, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, user)
response.Success(c, dto.UserFromService(user))
}

View File

@@ -0,0 +1,310 @@
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
func UserFromServiceShallow(u *service.User) *User {
if u == nil {
return nil
}
return &User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: u.AllowedGroups,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func UserFromService(u *service.User) *User {
if u == nil {
return nil
}
out := UserFromServiceShallow(u)
if len(u.ApiKeys) > 0 {
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
for i := range u.ApiKeys {
k := u.ApiKeys[i]
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
}
}
if len(u.Subscriptions) > 0 {
out.Subscriptions = make([]UserSubscription, 0, len(u.Subscriptions))
for i := range u.Subscriptions {
s := u.Subscriptions[i]
out.Subscriptions = append(out.Subscriptions, *UserSubscriptionFromService(&s))
}
}
return out
}
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
if k == nil {
return nil
}
return &ApiKey{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User),
Group: GroupFromServiceShallow(k.Group),
}
}
func GroupFromServiceShallow(g *service.Group) *Group {
if g == nil {
return nil
}
return &Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
}
}
func GroupFromService(g *service.Group) *Group {
if g == nil {
return nil
}
out := GroupFromServiceShallow(g)
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
for i := range g.AccountGroups {
ag := g.AccountGroups[i]
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
}
}
return out
}
func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
}
return &Account{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
}
}
func AccountFromService(a *service.Account) *Account {
if a == nil {
return nil
}
out := AccountFromServiceShallow(a)
out.Proxy = ProxyFromService(a.Proxy)
if len(a.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(a.AccountGroups))
for i := range a.AccountGroups {
ag := a.AccountGroups[i]
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
}
}
if len(a.Groups) > 0 {
out.Groups = make([]*Group, 0, len(a.Groups))
for _, g := range a.Groups {
out.Groups = append(out.Groups, GroupFromServiceShallow(g))
}
}
return out
}
func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup {
if ag == nil {
return nil
}
return &AccountGroup{
AccountID: ag.AccountID,
GroupID: ag.GroupID,
Priority: ag.Priority,
CreatedAt: ag.CreatedAt,
Account: AccountFromServiceShallow(ag.Account),
Group: GroupFromServiceShallow(ag.Group),
}
}
func ProxyFromService(p *service.Proxy) *Proxy {
if p == nil {
return nil
}
return &Proxy{
ID: p.ID,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWithAccountCount {
if p == nil {
return nil
}
return &ProxyWithAccountCount{
Proxy: *ProxyFromService(&p.Proxy),
AccountCount: p.AccountCount,
}
}
func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
if rc == nil {
return nil
}
return &RedeemCode{
ID: rc.ID,
Code: rc.Code,
Type: rc.Type,
Value: rc.Value,
Status: rc.Status,
UsedBy: rc.UsedBy,
UsedAt: rc.UsedAt,
Notes: rc.Notes,
CreatedAt: rc.CreatedAt,
GroupID: rc.GroupID,
ValidityDays: rc.ValidityDays,
User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group),
}
}
func UsageLogFromService(l *service.UsageLog) *UsageLog {
if l == nil {
return nil
}
return &UsageLog{
ID: l.ID,
UserID: l.UserID,
ApiKeyID: l.ApiKeyID,
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens,
OutputTokens: l.OutputTokens,
CacheCreationTokens: l.CacheCreationTokens,
CacheReadTokens: l.CacheReadTokens,
CacheCreation5mTokens: l.CacheCreation5mTokens,
CacheCreation1hTokens: l.CacheCreation1hTokens,
InputCost: l.InputCost,
OutputCost: l.OutputCost,
CacheCreationCost: l.CacheCreationCost,
CacheReadCost: l.CacheReadCost,
TotalCost: l.TotalCost,
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
BillingType: l.BillingType,
Stream: l.Stream,
DurationMs: l.DurationMs,
FirstTokenMs: l.FirstTokenMs,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
ApiKey: ApiKeyFromService(l.ApiKey),
Account: AccountFromService(l.Account),
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),
}
}
func SettingFromService(s *service.Setting) *Setting {
if s == nil {
return nil
}
return &Setting{
ID: s.ID,
Key: s.Key,
Value: s.Value,
UpdatedAt: s.UpdatedAt,
}
}
func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscription {
if sub == nil {
return nil
}
return &UserSubscription{
ID: sub.ID,
UserID: sub.UserID,
GroupID: sub.GroupID,
StartsAt: sub.StartsAt,
ExpiresAt: sub.ExpiresAt,
Status: sub.Status,
DailyWindowStart: sub.DailyWindowStart,
WeeklyWindowStart: sub.WeeklyWindowStart,
MonthlyWindowStart: sub.MonthlyWindowStart,
DailyUsageUSD: sub.DailyUsageUSD,
WeeklyUsageUSD: sub.WeeklyUsageUSD,
MonthlyUsageUSD: sub.MonthlyUsageUSD,
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
CreatedAt: sub.CreatedAt,
UpdatedAt: sub.UpdatedAt,
User: UserFromServiceShallow(sub.User),
Group: GroupFromServiceShallow(sub.Group),
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}
func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult {
if r == nil {
return nil
}
subs := make([]UserSubscription, 0, len(r.Subscriptions))
for i := range r.Subscriptions {
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
}
return &BulkAssignResult{
SuccessCount: r.SuccessCount,
FailedCount: r.FailedCount,
Subscriptions: subs,
Errors: r.Errors,
}
}

View File

@@ -0,0 +1,43 @@
package dto
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"`
}

View File

@@ -0,0 +1,219 @@
package dto
import "time"
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
Status string `json:"status"`
AllowedGroups []int64 `json:"allowed_groups"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ApiKeys []ApiKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
type ApiKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Key string `json:"key"`
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
type Group struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
Status string `json:"status"`
SubscriptionType string `json:"subscription_type"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
}
type Account struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Schedulable bool `json:"schedulable"`
RateLimitedAt *time.Time `json:"rate_limited_at"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `json:"session_window_status"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
GroupIDs []int64 `json:"group_ids,omitempty"`
Groups []*Group `json:"groups,omitempty"`
}
type AccountGroup struct {
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Priority int `json:"priority"`
CreatedAt time.Time `json:"created_at"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
}
type Proxy struct {
ID int64 `json:"id"`
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"-"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}
type RedeemCode struct {
ID int64 `json:"id"`
Code string `json:"code"`
Type string `json:"type"`
Value float64 `json:"value"`
Status string `json:"status"`
UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
type UsageLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationTokens int `json:"cache_creation_tokens"`
CacheReadTokens int `json:"cache_read_tokens"`
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
ApiKey *ApiKey `json:"api_key,omitempty"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`
}
type Setting struct {
ID int64 `json:"id"`
Key string `json:"key"`
Value string `json:"value"`
UpdatedAt time.Time `json:"updated_at"`
}
type UserSubscription struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
GroupID int64 `json:"group_id"`
StartsAt time.Time `json:"starts_at"`
ExpiresAt time.Time `json:"expires_at"`
Status string `json:"status"`
DailyWindowStart *time.Time `json:"daily_window_start"`
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
DailyUsageUSD float64 `json:"daily_usage_usd"`
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
AssignedBy *int64 `json:"assigned_by"`
AssignedAt time.Time `json:"assigned_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
}
type BulkAssignResult struct {
SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"`
Subscriptions []UserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
}

View File

@@ -7,37 +7,40 @@ import (
"io"
"log"
"net/http"
"strings"
"time"
"sub2api/internal/middleware"
"sub2api/internal/model"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// Maximum wait time for concurrency slot
maxConcurrencyWait = 60 * time.Second
// Ping interval during wait
pingInterval = 5 * time.Second
)
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
geminiCompatService *service.GeminiMessagesCompatService
userService *service.UserService
concurrencyService *service.ConcurrencyService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
}
// NewGatewayHandler creates a new GatewayHandler
func NewGatewayHandler(gatewayService *service.GatewayService, userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService) *GatewayHandler {
func NewGatewayHandler(
gatewayService *service.GatewayService,
geminiCompatService *service.GeminiMessagesCompatService,
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
) *GatewayHandler {
return &GatewayHandler{
gatewayService: gatewayService,
geminiCompatService: geminiCompatService,
userService: userService,
concurrencyService: concurrencyService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
}
}
@@ -45,13 +48,13 @@ func NewGatewayHandler(gatewayService *service.GatewayService, userService *serv
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware.GetApiKeyFromContext(c)
apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
@@ -83,11 +86,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
streamStarted := false
// 获取订阅信息可能为nil- 提前获取用于后续检查
subscription, _ := middleware.GetSubscriptionFromContext(c)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := h.concurrencyService.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
@@ -96,10 +99,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 确保在函数退出时减少wait计数
defer h.concurrencyService.DecrementWaitCount(c.Request.Context(), user.ID)
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.acquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
@@ -110,7 +113,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return
@@ -119,15 +122,35 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
platform := ""
if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
var account *service.Account
if platform == service.PlatformGemini {
account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
} else {
account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
}
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
@@ -138,7 +161,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 转发请求
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
var result *service.ForwardResult
if platform == service.PlatformGemini {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
}
if err != nil {
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Forward request failed: %v", err)
@@ -152,7 +180,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: user,
User: apiKey.User,
Account: account,
Subscription: subscription,
}); err != nil {
@@ -161,167 +189,38 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}()
}
// acquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary
// For streaming requests, sends ping events during the wait
// streamStarted is updated if streaming response has begun
func (h *GatewayHandler) acquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
}
// acquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary
// For streaming requests, sends ping events during the wait
// streamStarted is updated if streaming response has begun
func (h *GatewayHandler) acquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
}
// concurrencyError represents a concurrency limit error with context
type concurrencyError struct {
SlotType string
IsTimeout bool
}
func (e *concurrencyError) Error() string {
if e.IsTimeout {
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
}
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
}
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests
// Note: For streaming requests, we send ping to keep the connection alive.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller)
func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
defer cancel()
// For streaming requests, set up SSE headers for ping
var flusher http.Flusher
if isStream {
var ok bool
flusher, ok = c.Writer.(http.Flusher)
if !ok {
return nil, fmt.Errorf("streaming not supported")
}
}
pingTicker := time.NewTicker(pingInterval)
defer pingTicker.Stop()
pollTicker := time.NewTicker(100 * time.Millisecond)
defer pollTicker.Stop()
for {
select {
case <-ctx.Done():
return nil, &concurrencyError{
SlotType: slotType,
IsTimeout: true,
}
case <-pingTicker.C:
// Send ping for streaming requests to keep connection alive
if isStream && flusher != nil {
// Set headers on first ping (lazy initialization)
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
flusher.Flush()
}
case <-pollTicker.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
}
}
// Models handles listing available models
// GET /v1/models
// Returns different model lists based on the API key's group platform
func (h *GatewayHandler) Models(c *gin.Context) {
models := []gin.H{
{
"id": "claude-opus-4-5-20251101",
"type": "model",
"display_name": "Claude Opus 4.5",
"created_at": "2025-11-01T00:00:00Z",
},
{
"id": "claude-sonnet-4-5-20250929",
"type": "model",
"display_name": "Claude Sonnet 4.5",
"created_at": "2025-09-29T00:00:00Z",
},
{
"id": "claude-haiku-4-5-20251001",
"type": "model",
"display_name": "Claude Haiku 4.5",
"created_at": "2025-10-01T00:00:00Z",
},
apiKey, _ := middleware2.GetApiKeyFromContext(c)
// Return OpenAI models for OpenAI platform groups
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": openai.DefaultModels,
})
return
}
// Default: Claude models
c.JSON(http.StatusOK, gin.H{
"data": models,
"object": "list",
"data": claude.DefaultModels,
})
}
// Usage handles getting account balance for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
@@ -329,7 +228,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
// 订阅模式:返回订阅限额信息
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware.GetSubscriptionFromContext(c)
subscription, ok := middleware2.GetSubscriptionFromContext(c)
if !ok {
h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription")
return
@@ -346,7 +245,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
}
// 余额模式:返回钱包余额
latestUser, err := h.userService.GetByID(c.Request.Context(), user.ID)
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return
@@ -364,7 +263,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
// 逻辑:
// 1. 如果日/周/月任一限额达到100%返回0
// 2. 否则返回所有已配置周期中剩余额度的最小值
func (h *GatewayHandler) calculateSubscriptionRemaining(group *model.Group, sub *model.UserSubscription) float64 {
func (h *GatewayHandler) calculateSubscriptionRemaining(group *service.Group, sub *service.UserSubscription) float64 {
var remainingValues []float64
// 检查日限额
@@ -423,7 +322,9 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
if ok {
// Send error event in SSE format
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
fmt.Fprint(c.Writer, errorEvent)
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
}
return
@@ -443,3 +344,155 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
},
})
}
// CountTokens handles token counting endpoint
// POST /v1/messages/count_tokens
// 特点:校验订阅/余额,但不计算并发、不记录使用量
func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
_, ok = middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// 解析请求获取模型名
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 获取订阅信息可能为nil
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
return
}
// 计算粘性会话 hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
log.Printf("Forward count_tokens request failed: %v", err)
// 错误响应已在 ForwardCountTokens 中处理
return
}
}
// isWarmupRequest 检测是否为预热请求标题生成、Warmup等
func isWarmupRequest(body []byte) bool {
// 快速检查如果body不包含关键字直接返回false
bodyStr := string(body)
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
return false
}
// 解析完整请求
var req struct {
Messages []struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
} `json:"messages"`
System []struct {
Text string `json:"text"`
} `json:"system"`
}
if err := json.Unmarshal(body, &req); err != nil {
return false
}
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return true
}
}
}
}
// 检查 system 中的标题提取模式
for _, system := range req.System {
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return true
}
}
return false
}
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
func sendMockWarmupStream(c *gin.Context, model string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
events := []string{
`event: message_start` + "\n" + `data: {"message":{"content":[],"id":"msg_mock_warmup","model":"` + model + `","role":"assistant","stop_reason":null,"stop_sequence":null,"type":"message","usage":{"input_tokens":10,"output_tokens":0}},"type":"message_start"}`,
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
}
for _, event := range events {
_, _ = c.Writer.WriteString(event + "\n\n")
c.Writer.Flush()
time.Sleep(20 * time.Millisecond)
}
}
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) {
c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup",
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
"stop_reason": "end_turn",
"usage": gin.H{
"input_tokens": 10,
"output_tokens": 2,
},
})
}

View File

@@ -0,0 +1,179 @@
package handler
import (
"context"
"fmt"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// maxConcurrencyWait is the maximum time to wait for a concurrency slot
maxConcurrencyWait = 30 * time.Second
// pingInterval is the interval for sending ping events during slot wait
pingInterval = 15 * time.Second
)
// SSEPingFormat defines the format of SSE ping events for different platforms
type SSEPingFormat string
const (
// SSEPingFormatClaude is the Claude/Anthropic SSE ping format
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
SSEPingFormatNone SSEPingFormat = ""
)
// ConcurrencyError represents a concurrency limit error with context
type ConcurrencyError struct {
SlotType string
IsTimeout bool
}
func (e *ConcurrencyError) Error() string {
if e.IsTimeout {
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
}
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
}
// ConcurrencyHelper provides common concurrency slot management for gateway handlers
type ConcurrencyHelper struct {
concurrencyService *service.ConcurrencyService
pingFormat SSEPingFormat
}
// NewConcurrencyHelper creates a new ConcurrencyHelper
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
return &ConcurrencyHelper{
concurrencyService: concurrencyService,
pingFormat: pingFormat,
}
}
// IncrementWaitCount increments the wait count for a user
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
}
// DecrementWaitCount decrements the wait count for a user
func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) {
h.concurrencyService.DecrementWaitCount(ctx, userID)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", userID, maxConcurrency, isStream, streamStarted)
}
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", accountID, maxConcurrency, isStream, streamStarted)
}
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
defer cancel()
// Determine if ping is needed (streaming + ping format defined)
needPing := isStream && h.pingFormat != ""
var flusher http.Flusher
if needPing {
var ok bool
flusher, ok = c.Writer.(http.Flusher)
if !ok {
return nil, fmt.Errorf("streaming not supported")
}
}
// Only create ping ticker if ping is needed
var pingCh <-chan time.Time
if needPing {
pingTicker := time.NewTicker(pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}
pollTicker := time.NewTicker(100 * time.Millisecond)
defer pollTicker.Stop()
for {
select {
case <-ctx.Done():
return nil, &ConcurrencyError{
SlotType: slotType,
IsTimeout: true,
}
case <-pingCh:
// Send ping to keep connection alive
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
return nil, err
}
flusher.Flush()
case <-pollTicker.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
}
}

View File

@@ -0,0 +1,272 @@
package handler
import (
"context"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models")
if err != nil {
googleError(c, http.StatusBadGateway, err.Error())
return
}
if shouldFallbackGeminiModels(res) {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
writeUpstreamResponse(c, res)
}
// GeminiV1BetaGetModel proxies:
// GET /v1beta/models/{model}
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
modelName := strings.TrimSpace(c.Param("model"))
if modelName == "" {
googleError(c, http.StatusBadRequest, "Missing model in URL")
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models/"+modelName)
if err != nil {
googleError(c, http.StatusBadGateway, err.Error())
return
}
if shouldFallbackGeminiModels(res) {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
writeUpstreamResponse(c, res)
}
// GeminiV1BetaModels proxies Gemini native REST endpoints like:
// POST /v1beta/models/{model}:generateContent
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
authSubject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok {
googleError(c, http.StatusInternalServerError, "User context not found")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))
if err != nil {
googleError(c, http.StatusNotFound, err.Error())
return
}
stream := action == "streamGenerateContent"
body, err := io.ReadAll(c.Request.Body)
if err != nil {
googleError(c, http.StatusBadRequest, "Failed to read request body")
return
}
if len(body) == 0 {
googleError(c, http.StatusBadRequest, "Request body is empty")
return
}
// Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
// 0) wait queue check
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
} else if !canWait {
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
// 1) user concurrency slot
streamStarted := false
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error())
return
}
// 3) select account (sticky session based on request body)
sessionHash := h.gatewayService.GenerateSessionHash(body)
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, modelName)
if err != nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
// 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// 5) forward (writes response to client)
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
if err != nil {
// ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err)
return
}
// 6) record usage async
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}()
}
func parseGeminiModelAction(rest string) (model string, action string, err error) {
rest = strings.TrimSpace(rest)
if rest == "" {
return "", "", &pathParseError{"missing path"}
}
// Standard: {model}:{action}
if i := strings.Index(rest, ":"); i > 0 && i < len(rest)-1 {
return rest[:i], rest[i+1:], nil
}
// Fallback: {model}/{action}
if i := strings.Index(rest, "/"); i > 0 && i < len(rest)-1 {
return rest[:i], rest[i+1:], nil
}
return "", "", &pathParseError{"invalid model action path"}
}
type pathParseError struct{ msg string }
func (e *pathParseError) Error() string { return e.msg }
func googleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
}
func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
if res == nil {
googleError(c, http.StatusBadGateway, "Empty upstream response")
return
}
for k, vv := range res.Headers {
// Avoid overriding content-length and hop-by-hop headers.
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
continue
}
for _, v := range vv {
c.Writer.Header().Add(k, v)
}
}
contentType := res.Headers.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(res.StatusCode, contentType, res.Body)
}
func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
if res == nil {
return true
}
if res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden {
return false
}
if strings.Contains(strings.ToLower(res.Headers.Get("Www-Authenticate")), "insufficient_scope") {
return true
}
if strings.Contains(strings.ToLower(string(res.Body)), "insufficient authentication scopes") {
return true
}
if strings.Contains(strings.ToLower(string(res.Body)), "access_token_scope_insufficient") {
return true
}
return false
}

View File

@@ -1,7 +1,7 @@
package handler
import (
"sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
)
// AdminHandlers contains all admin-related HTTP handlers
@@ -11,6 +11,8 @@ type AdminHandlers struct {
Group *admin.GroupHandler
Account *admin.AccountHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Setting *admin.SettingHandler
@@ -21,15 +23,16 @@ type AdminHandlers struct {
// Handlers contains all HTTP handlers
type Handlers struct {
Auth *AuthHandler
User *UserHandler
APIKey *APIKeyHandler
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
Admin *AdminHandlers
Gateway *GatewayHandler
Setting *SettingHandler
Auth *AuthHandler
User *UserHandler
APIKey *APIKeyHandler
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler
}
// BuildInfo contains build-time information

View File

@@ -0,0 +1,209 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// OpenAIGatewayHandler handles OpenAI API gateway requests
type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
) *OpenAIGatewayHandler {
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
}
}
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
// Read request body
body, err := io.ReadAll(c.Request.Body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// Parse request body to map for potential modification
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
// For non-Codex CLI requests, set default instructions
userAgent := c.GetHeader("User-Agent")
if !openai.IsCodexCLIRequest(userAgent) {
reqBody["instructions"] = openai.DefaultInstructions
// Re-serialize body
body, err = json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
// Track if we've started streaming (for error handling)
streamStarted := false
// Get subscription info (may be nil)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
// Ensure wait count is decremented when function exits
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return
}
// Generate session hash (from header for OpenAI)
sessionHash := h.gatewayService.GenerateSessionHash(c)
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if err != nil {
// Error response already handled in Forward, just log
log.Printf("Forward request failed: %v", err)
return
}
// Async record usage
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}()
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in OpenAI SSE format
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
}
return
}
// Normal case: return JSON response with proper status code
h.errorResponse(c, status, errType, message)
}
// errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@@ -1,9 +1,10 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -37,15 +38,9 @@ type RedeemResponse struct {
// Redeem handles redeeming a code
// POST /api/v1/redeem
func (h *RedeemHandler) Redeem(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -55,38 +50,36 @@ func (h *RedeemHandler) Redeem(c *gin.Context) {
return
}
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
result, err := h.redeemService.Redeem(c.Request.Context(), subject.UserID, req.Code)
if err != nil {
response.BadRequest(c, "Failed to redeem code: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
response.Success(c, dto.RedeemCodeFromService(result))
}
// GetHistory returns the user's redemption history
// GET /api/v1/redeem/history
func (h *RedeemHandler) GetHistory(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
// Default limit is 25
limit := 25
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), subject.UserID, limit)
if err != nil {
response.InternalError(c, "Failed to get history: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, codes)
out := make([]dto.RedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
}
response.Success(c, out)
}

View File

@@ -1,8 +1,9 @@
package handler
import (
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -26,10 +27,21 @@ func NewSettingHandler(settingService *service.SettingService, version string) *
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
settings, err := h.settingService.GetPublicSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error())
response.ErrorFrom(c, err)
return
}
settings.Version = h.version
response.Success(c, settings)
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
Version: h.version,
})
}

View File

@@ -1,9 +1,10 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -25,7 +26,7 @@ type SubscriptionSummaryItem struct {
// SubscriptionProgressInfo represents subscription with progress info
type SubscriptionProgressInfo struct {
Subscription *model.UserSubscription `json:"subscription"`
Subscription *dto.UserSubscription `json:"subscription"`
Progress *service.SubscriptionProgress `json:"progress"`
}
@@ -44,70 +45,60 @@ func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *S
// List handles listing current user's subscriptions
// GET /api/v1/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, subscriptions)
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
}
// GetActive handles getting current user's active subscriptions
// GET /api/v1/subscriptions/active
func (h *SubscriptionHandler) GetActive(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to get active subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, subscriptions)
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
}
// GetProgress handles getting subscription progress for current user
// GET /api/v1/subscriptions/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions with progress
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -120,7 +111,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
continue
}
result = append(result, SubscriptionProgressInfo{
Subscription: sub,
Subscription: dto.UserSubscriptionFromService(sub),
Progress: progress,
})
}
@@ -131,22 +122,16 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
// GetSummary handles getting a summary of current user's subscription status
// GET /api/v1/subscriptions/summary
func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -4,11 +4,12 @@ import (
"strconv"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -16,15 +17,13 @@ import (
// UsageHandler handles usage-related requests
type UsageHandler struct {
usageService *service.UsageService
usageRepo *repository.UsageLogRepository
apiKeyService *service.ApiKeyService
}
// NewUsageHandler creates a new UsageHandler
func NewUsageHandler(usageService *service.UsageService, usageRepo *repository.UsageLogRepository, apiKeyService *service.ApiKeyService) *UsageHandler {
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
return &UsageHandler{
usageService: usageService,
usageRepo: usageRepo,
apiKeyService: apiKeyService,
}
}
@@ -32,15 +31,9 @@ func NewUsageHandler(usageService *service.UsageService, usageRepo *repository.U
// List handles listing usage records with pagination
// GET /api/v1/usage
func (h *UsageHandler) List(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -57,10 +50,10 @@ func (h *UsageHandler) List(c *gin.Context) {
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "API key not found")
response.ErrorFrom(c, err)
return
}
if apiKey.UserID != user.ID {
if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's usage records")
return
}
@@ -68,36 +61,34 @@ func (h *UsageHandler) List(c *gin.Context) {
apiKeyID = id
}
params := repository.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog
var result *repository.PaginationResult
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
var records []service.UsageLog
var result *pagination.PaginationResult
var err error
if apiKeyID > 0 {
records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params)
} else {
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
records, result, err = h.usageService.ListByUser(c.Request.Context(), subject.UserID, params)
}
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Paginated(c, records, result.Total, page, pageSize)
out := make([]dto.UsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromService(&records[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}
// GetByID handles getting a single usage record
// GET /api/v1/usage/:id
func (h *UsageHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -109,31 +100,25 @@ func (h *UsageHandler) GetByID(c *gin.Context) {
record, err := h.usageService.GetByID(c.Request.Context(), usageID)
if err != nil {
response.NotFound(c, "Usage record not found")
response.ErrorFrom(c, err)
return
}
// 验证所有权
if record.UserID != user.ID {
if record.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this record")
return
}
response.Success(c, record)
response.Success(c, dto.UsageLogFromService(record))
}
// Stats handles getting usage statistics
// GET /api/v1/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -151,7 +136,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
response.NotFound(c, "API key not found")
return
}
if apiKey.UserID != user.ID {
if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's statistics")
return
}
@@ -203,10 +188,10 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if apiKeyID > 0 {
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
} else {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
}
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -247,21 +232,15 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
// DashboardStats handles getting user dashboard statistics
// GET /api/v1/usage/dashboard/stats
func (h *UsageHandler) DashboardStats(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
stats, err := h.usageRepo.GetUserDashboardStats(c.Request.Context(), user.ID)
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to get dashboard statistics")
response.ErrorFrom(c, err)
return
}
@@ -271,24 +250,18 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
// DashboardTrend handles getting user usage trend data
// GET /api/v1/usage/dashboard/trend
func (h *UsageHandler) DashboardTrend(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
startTime, endTime := parseUserTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageRepo.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), subject.UserID, startTime, endTime, granularity)
if err != nil {
response.InternalError(c, "Failed to get usage trend")
response.ErrorFrom(c, err)
return
}
@@ -303,23 +276,17 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
// DashboardModels handles getting user model usage statistics
// GET /api/v1/usage/dashboard/models
func (h *UsageHandler) DashboardModels(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
startTime, endTime := parseUserTimeRange(c)
stats, err := h.usageRepo.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), subject.UserID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get model statistics")
response.ErrorFrom(c, err)
return
}
@@ -338,15 +305,9 @@ type BatchApiKeysUsageRequest struct {
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
// POST /api/v1/usage/dashboard/api-keys-usage
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -357,14 +318,14 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
}
if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
// Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000})
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil {
response.InternalError(c, "Failed to verify API key ownership")
response.ErrorFrom(c, err)
return
}
@@ -382,13 +343,13 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
}
if len(validApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
if err != nil {
response.InternalError(c, "Failed to get API key usage stats")
response.ErrorFrom(c, err)
return
}

View File

@@ -1,9 +1,10 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -26,42 +27,39 @@ type ChangePasswordRequest struct {
NewPassword string `json:"new_password" binding:"required,min=6"`
}
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
Wechat *string `json:"wechat"`
}
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.InternalError(c, "Failed to get user profile: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, userData)
// 清空notes字段普通用户不应看到备注
userData.Notes = ""
response.Success(c, dto.UserFromService(userData))
}
// ChangePassword handles changing user password
// POST /api/v1/users/me/password
func (h *UserHandler) ChangePassword(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -75,11 +73,42 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
CurrentPassword: req.OldPassword,
NewPassword: req.NewPassword,
}
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
err := h.userService.ChangePassword(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.BadRequest(c, "Failed to change password: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Password changed successfully"})
}
// UpdateProfile handles updating user profile
// PUT /api/v1/users/me
func (h *UserHandler) UpdateProfile(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req UpdateProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
svcReq := service.UpdateProfileRequest{
Username: req.Username,
Wechat: req.Wechat,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.ErrorFrom(c, err)
return
}
// 清空notes字段普通用户不应看到备注
updatedUser.Notes = ""
response.Success(c, dto.UserFromService(updatedUser))
}

View File

@@ -1,11 +1,10 @@
package handler
import (
"sub2api/internal/handler/admin"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
)
// ProvideAdminHandlers creates the AdminHandlers struct
@@ -15,6 +14,8 @@ func ProvideAdminHandlers(
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler,
settingHandler *admin.SettingHandler,
@@ -28,6 +29,8 @@ func ProvideAdminHandlers(
Group: groupHandler,
Account: accountHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
Proxy: proxyHandler,
Redeem: redeemHandler,
Setting: settingHandler,
@@ -37,9 +40,9 @@ func ProvideAdminHandlers(
}
}
// ProvideSystemHandler creates admin.SystemHandler with BuildInfo parameters
func ProvideSystemHandler(rdb *redis.Client, buildInfo BuildInfo) *admin.SystemHandler {
return admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType)
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
return admin.NewSystemHandler(updateService)
}
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
@@ -57,18 +60,20 @@ func ProvideHandlers(
subscriptionHandler *SubscriptionHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
settingHandler *SettingHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
User: userHandler,
APIKey: apiKeyHandler,
Usage: usageHandler,
Redeem: redeemHandler,
Subscription: subscriptionHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
Setting: settingHandler,
Auth: authHandler,
User: userHandler,
APIKey: apiKeyHandler,
Usage: usageHandler,
Redeem: redeemHandler,
Subscription: subscriptionHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler,
}
}
@@ -82,6 +87,7 @@ var ProviderSet = wire.NewSet(
NewRedeemHandler,
NewSubscriptionHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
ProvideSettingHandler,
// Admin handlers
@@ -90,6 +96,8 @@ var ProviderSet = wire.NewSet(
admin.NewGroupHandler,
admin.NewAccountHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewSettingHandler,

View File

@@ -1,9 +1,9 @@
package infrastructure
import (
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/repository"
"gorm.io/driver/postgres"
"gorm.io/gorm"
@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil {
if err := repository.AutoMigrate(db); err != nil {
return nil, err
}

View File

@@ -0,0 +1,158 @@
package errors
import (
"errors"
"fmt"
"net/http"
)
const (
UnknownCode = http.StatusInternalServerError
UnknownReason = ""
UnknownMessage = "internal error"
)
type Status struct {
Code int32 `json:"code"`
Reason string `json:"reason,omitempty"`
Message string `json:"message"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// ApplicationError is the standard error type used to control HTTP responses.
//
// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500).
type ApplicationError struct {
Status
cause error
}
// Error is kept for backwards compatibility within this package.
type Error = ApplicationError
func (e *ApplicationError) Error() string {
if e == nil {
return "<nil>"
}
if e.cause == nil {
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata)
}
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause)
}
// Unwrap provides compatibility for Go 1.13 error chains.
func (e *ApplicationError) Unwrap() error { return e.cause }
// Is matches each error in the chain with the target value.
func (e *ApplicationError) Is(err error) bool {
if se := new(ApplicationError); errors.As(err, &se) {
return se.Code == e.Code && se.Reason == e.Reason
}
return false
}
// WithCause attaches the underlying cause of the error.
func (e *ApplicationError) WithCause(cause error) *ApplicationError {
err := Clone(e)
err.cause = cause
return err
}
// WithMetadata deep-copies the given metadata map.
func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError {
err := Clone(e)
if md == nil {
err.Metadata = nil
return err
}
err.Metadata = make(map[string]string, len(md))
for k, v := range md {
err.Metadata[k] = v
}
return err
}
// New returns an error object for the code, message.
func New(code int, reason, message string) *ApplicationError {
return &ApplicationError{
Status: Status{
Code: int32(code),
Message: message,
Reason: reason,
},
}
}
// Newf New(code fmt.Sprintf(format, a...))
func Newf(code int, reason, format string, a ...any) *ApplicationError {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Errorf returns an error object for the code, message and error info.
func Errorf(code int, reason, format string, a ...any) error {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Code returns the http code for an error.
// It supports wrapped errors.
func Code(err error) int {
if err == nil {
return http.StatusOK
}
return int(FromError(err).Code)
}
// Reason returns the reason for a particular error.
// It supports wrapped errors.
func Reason(err error) string {
if err == nil {
return UnknownReason
}
return FromError(err).Reason
}
// Message returns the message for a particular error.
// It supports wrapped errors.
func Message(err error) string {
if err == nil {
return ""
}
return FromError(err).Message
}
// Clone deep clone error to a new error.
func Clone(err *ApplicationError) *ApplicationError {
if err == nil {
return nil
}
var metadata map[string]string
if err.Metadata != nil {
metadata = make(map[string]string, len(err.Metadata))
for k, v := range err.Metadata {
metadata[k] = v
}
}
return &ApplicationError{
cause: err.cause,
Status: Status{
Code: err.Code,
Reason: err.Reason,
Message: err.Message,
Metadata: metadata,
},
}
}
// FromError tries to convert an error to *ApplicationError.
// It supports wrapped errors.
func FromError(err error) *ApplicationError {
if err == nil {
return nil
}
if se := new(ApplicationError); errors.As(err, &se) {
return se
}
// Fall back to a generic internal error.
return New(UnknownCode, UnknownReason, UnknownMessage).WithCause(err)
}

View File

@@ -0,0 +1,168 @@
//go:build unit
package errors
import (
stderrors "errors"
"fmt"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestApplicationError_Basics(t *testing.T) {
tests := []struct {
name string
err *ApplicationError
want Status
wantIs bool
target error
wrapped error
}{
{
name: "new",
err: New(400, "BAD_REQUEST", "invalid input"),
want: Status{
Code: 400,
Reason: "BAD_REQUEST",
Message: "invalid input",
},
},
{
name: "is_matches_code_and_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "UNAUTHORIZED", "ignored message"),
wantIs: true,
},
{
name: "is_does_not_match_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "DIFFERENT", "ignored message"),
wantIs: false,
},
{
name: "from_error_unwraps_wrapped_application_error",
err: New(404, "NOT_FOUND", "missing"),
wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")),
want: Status{
Code: 404,
Reason: "NOT_FOUND",
Message: "missing",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.err != nil {
require.Equal(t, tt.want, tt.err.Status)
}
if tt.target != nil {
require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target))
}
if tt.wrapped != nil {
got := FromError(tt.wrapped)
require.Equal(t, tt.want, got.Status)
}
})
}
}
func TestApplicationError_WithMetadataDeepCopy(t *testing.T) {
tests := []struct {
name string
md map[string]string
}{
{name: "non_nil", md: map[string]string{"a": "1"}},
{name: "nil", md: nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md)
if tt.md == nil {
require.Nil(t, appErr.Metadata)
return
}
tt.md["a"] = "changed"
require.Equal(t, "1", appErr.Metadata["a"])
})
}
}
func TestFromError_Generic(t *testing.T) {
tests := []struct {
name string
err error
wantCode int32
wantReason string
wantMsg string
}{
{
name: "plain_error",
err: stderrors.New("boom"),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: UnknownMessage,
},
{
name: "wrapped_plain_error",
err: fmt.Errorf("wrap: %w", io.EOF),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: UnknownMessage,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FromError(tt.err)
require.Equal(t, tt.wantCode, got.Code)
require.Equal(t, tt.wantReason, got.Reason)
require.Equal(t, tt.wantMsg, got.Message)
require.Equal(t, tt.err, got.Unwrap())
})
}
}
func TestToHTTP(t *testing.T) {
tests := []struct {
name string
err error
wantStatusCode int
wantBody Status
}{
{
name: "nil_error",
err: nil,
wantStatusCode: http.StatusOK,
wantBody: Status{Code: int32(http.StatusOK)},
},
{
name: "application_error",
err: Forbidden("FORBIDDEN", "no access"),
wantStatusCode: http.StatusForbidden,
wantBody: Status{
Code: int32(http.StatusForbidden),
Reason: "FORBIDDEN",
Message: "no access",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
code, body := ToHTTP(tt.err)
require.Equal(t, tt.wantStatusCode, code)
require.Equal(t, tt.wantBody, body)
})
}
}

View File

@@ -0,0 +1,21 @@
package errors
import "net/http"
// ToHTTP converts an error into an HTTP status code and a JSON-serializable body.
//
// The returned body matches the project's Status shape:
// { code, reason, message, metadata }.
func ToHTTP(err error) (statusCode int, body Status) {
if err == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
appErr := FromError(err)
if appErr == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
cloned := Clone(appErr)
return int(cloned.Code), cloned.Status
}

View File

@@ -0,0 +1,114 @@
// nolint:mnd
package errors
import "net/http"
// BadRequest new BadRequest error that is mapped to a 400 response.
func BadRequest(reason, message string) *ApplicationError {
return New(http.StatusBadRequest, reason, message)
}
// IsBadRequest determines if err is an error which indicates a BadRequest error.
// It supports wrapped errors.
func IsBadRequest(err error) bool {
return Code(err) == http.StatusBadRequest
}
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
func TooManyRequests(reason, message string) *ApplicationError {
return New(http.StatusTooManyRequests, reason, message)
}
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
// It supports wrapped errors.
func IsTooManyRequests(err error) bool {
return Code(err) == http.StatusTooManyRequests
}
// Unauthorized new Unauthorized error that is mapped to a 401 response.
func Unauthorized(reason, message string) *ApplicationError {
return New(http.StatusUnauthorized, reason, message)
}
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
// It supports wrapped errors.
func IsUnauthorized(err error) bool {
return Code(err) == http.StatusUnauthorized
}
// Forbidden new Forbidden error that is mapped to a 403 response.
func Forbidden(reason, message string) *ApplicationError {
return New(http.StatusForbidden, reason, message)
}
// IsForbidden determines if err is an error which indicates a Forbidden error.
// It supports wrapped errors.
func IsForbidden(err error) bool {
return Code(err) == http.StatusForbidden
}
// NotFound new NotFound error that is mapped to a 404 response.
func NotFound(reason, message string) *ApplicationError {
return New(http.StatusNotFound, reason, message)
}
// IsNotFound determines if err is an error which indicates an NotFound error.
// It supports wrapped errors.
func IsNotFound(err error) bool {
return Code(err) == http.StatusNotFound
}
// Conflict new Conflict error that is mapped to a 409 response.
func Conflict(reason, message string) *ApplicationError {
return New(http.StatusConflict, reason, message)
}
// IsConflict determines if err is an error which indicates a Conflict error.
// It supports wrapped errors.
func IsConflict(err error) bool {
return Code(err) == http.StatusConflict
}
// InternalServer new InternalServer error that is mapped to a 500 response.
func InternalServer(reason, message string) *ApplicationError {
return New(http.StatusInternalServerError, reason, message)
}
// IsInternalServer determines if err is an error which indicates an Internal error.
// It supports wrapped errors.
func IsInternalServer(err error) bool {
return Code(err) == http.StatusInternalServerError
}
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
func ServiceUnavailable(reason, message string) *ApplicationError {
return New(http.StatusServiceUnavailable, reason, message)
}
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
// It supports wrapped errors.
func IsServiceUnavailable(err error) bool {
return Code(err) == http.StatusServiceUnavailable
}
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
func GatewayTimeout(reason, message string) *ApplicationError {
return New(http.StatusGatewayTimeout, reason, message)
}
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
// It supports wrapped errors.
func IsGatewayTimeout(err error) bool {
return Code(err) == http.StatusGatewayTimeout
}
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
func ClientClosed(reason, message string) *ApplicationError {
return New(499, reason, message)
}
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
// It supports wrapped errors.
func IsClientClosed(err error) bool {
return Code(err) == 499
}

View File

@@ -1,7 +1,7 @@
package infrastructure
import (
"sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/redis/go-redis/v9"
)

View File

@@ -1,7 +1,7 @@
package infrastructure
import (
"sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"

View File

@@ -1,265 +0,0 @@
package model
import (
"database/sql/driver"
"encoding/json"
"errors"
"time"
"gorm.io/gorm"
)
// JSONB 用于存储JSONB数据
type JSONB map[string]interface{}
func (j JSONB) Value() (driver.Value, error) {
if j == nil {
return nil, nil
}
return json.Marshal(j)
}
func (j *JSONB) Scan(value interface{}) error {
if value == nil {
*j = nil
return nil
}
bytes, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(bytes, j)
}
type Account struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Platform string `gorm:"size:50;not null" json:"platform"` // anthropic/openai/gemini
Type string `gorm:"size:20;not null" json:"type"` // oauth/apikey
Credentials JSONB `gorm:"type:jsonb;default:'{}'" json:"credentials"` // 凭证(加密存储)
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
ProxyID *int64 `gorm:"index" json:"proxy_id"`
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100越小越高
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
ErrorMessage string `gorm:"type:text" json:"error_message"`
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 调度控制
Schedulable bool `gorm:"default:true;not null" json:"schedulable"`
// 限流状态 (429)
RateLimitedAt *time.Time `gorm:"index" json:"rate_limited_at"`
RateLimitResetAt *time.Time `gorm:"index" json:"rate_limit_reset_at"`
// 过载状态 (529)
OverloadUntil *time.Time `gorm:"index" json:"overload_until"`
// 5小时时间窗口
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `gorm:"size:20" json:"session_window_status"` // allowed/allowed_warning/rejected
// 关联
Proxy *Proxy `gorm:"foreignKey:ProxyID" json:"proxy,omitempty"`
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
}
func (Account) TableName() string {
return "accounts"
}
// IsActive 检查是否激活
func (a *Account) IsActive() bool {
return a.Status == "active"
}
// IsSchedulable 检查账号是否可调度
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
}
now := time.Now()
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false
}
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
return true
}
// IsRateLimited 检查是否处于限流状态
func (a *Account) IsRateLimited() bool {
if a.RateLimitResetAt == nil {
return false
}
return time.Now().Before(*a.RateLimitResetAt)
}
// IsOverloaded 检查是否处于过载状态
func (a *Account) IsOverloaded() bool {
if a.OverloadUntil == nil {
return false
}
return time.Now().Before(*a.OverloadUntil)
}
// IsOAuth 检查是否为OAuth类型账号包括oauth和setup-token
func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
}
// CanGetUsage 检查账号是否可以获取usage信息只有oauth类型可以setup-token没有profile权限
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}
// GetCredential 获取凭证字段
func (a *Account) GetCredential(key string) string {
if a.Credentials == nil {
return ""
}
if v, ok := a.Credentials[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// GetModelMapping 获取模型映射配置
// 返回格式: map[请求模型名]实际模型名
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["model_mapping"]
if !ok || raw == nil {
return nil
}
// 处理map[string]interface{}类型
if m, ok := raw.(map[string]interface{}); ok {
result := make(map[string]string)
for k, v := range m {
if s, ok := v.(string); ok {
result[k] = s
}
}
if len(result) > 0 {
return result
}
}
return nil
}
// IsModelSupported 检查请求的模型是否被该账号支持
// 如果没有设置模型映射,则支持所有模型
func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping()
if mapping == nil || len(mapping) == 0 {
return true // 没有映射配置,支持所有模型
}
_, exists := mapping[requestedModel]
return exists
}
// GetMappedModel 获取映射后的实际模型名
// 如果没有映射,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping()
if mapping == nil || len(mapping) == 0 {
return requestedModel
}
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel
}
return requestedModel
}
// GetBaseURL 获取API基础URL用于apikey类型账号
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
return ""
}
baseURL := a.GetCredential("base_url")
if baseURL == "" {
return "https://api.anthropic.com" // 默认URL
}
return baseURL
}
// GetExtraString 从Extra字段获取字符串值
func (a *Account) GetExtraString(key string) string {
if a.Extra == nil {
return ""
}
if v, ok := a.Extra[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// IsCustomErrorCodesEnabled 检查是否启用自定义错误码功能(仅适用于 apikey 类型)
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// GetCustomErrorCodes 获取自定义错误码列表
func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["custom_error_codes"]
if !ok || raw == nil {
return nil
}
// 处理 []interface{} 类型JSON反序列化后的格式
if arr, ok := raw.([]interface{}); ok {
result := make([]int, 0, len(arr))
for _, v := range arr {
// JSON 数字默认解析为 float64
if f, ok := v.(float64); ok {
result = append(result, int(f))
}
}
return result
}
return nil
}
// ShouldHandleErrorCode 检查指定错误码是否应该被处理(停止调度/标记限流等)
// 如果未启用自定义错误码或列表为空,返回 true使用默认策略
// 如果启用且列表非空,只有在列表中的错误码才返回 true
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
if !a.IsCustomErrorCodesEnabled() {
return true // 未启用,使用默认策略
}
codes := a.GetCustomErrorCodes()
if len(codes) == 0 {
return true // 启用但列表为空fallback到默认策略
}
// 检查是否在自定义列表中
for _, code := range codes {
if code == statusCode {
return true
}
}
return false
}

View File

@@ -1,20 +0,0 @@
package model
import (
"time"
)
type AccountGroup struct {
AccountID int64 `gorm:"primaryKey" json:"account_id"`
GroupID int64 `gorm:"primaryKey" json:"group_id"`
Priority int `gorm:"default:50;not null" json:"priority"` // 分组内优先级
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 关联
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (AccountGroup) TableName() string {
return "account_groups"
}

View File

@@ -1,32 +0,0 @@
package model
import (
"time"
"gorm.io/gorm"
)
type ApiKey struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
Key string `gorm:"uniqueIndex;size:128;not null" json:"key"` // sk-xxx
Name string `gorm:"size:100;not null" json:"name"`
GroupID *int64 `gorm:"index" json:"group_id"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (ApiKey) TableName() string {
return "api_keys"
}
// IsActive 检查是否激活
func (k *ApiKey) IsActive() bool {
return k.Status == "active"
}

View File

@@ -1,73 +0,0 @@
package model
import (
"time"
"gorm.io/gorm"
)
// 订阅类型常量
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
type Group struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"`
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
// 订阅功能字段
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
AccountGroups []AccountGroup `gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
AccountCount int64 `gorm:"-" json:"account_count,omitempty"`
}
func (Group) TableName() string {
return "groups"
}
// IsActive 检查是否激活
func (g *Group) IsActive() bool {
return g.Status == "active"
}
// IsSubscriptionType 检查是否为订阅类型分组
func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
// HasDailyLimit 检查是否有日限额
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
// HasWeeklyLimit 检查是否有周限额
func (g *Group) HasWeeklyLimit() bool {
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
}
// HasMonthlyLimit 检查是否有月限额
func (g *Group) HasMonthlyLimit() bool {
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
}

View File

@@ -1,64 +0,0 @@
package model
import (
"gorm.io/gorm"
)
// AutoMigrate 自动迁移所有模型
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&User{},
&ApiKey{},
&Group{},
&Account{},
&AccountGroup{},
&Proxy{},
&RedeemCode{},
&UsageLog{},
&Setting{},
&UserSubscription{},
)
}
// 状态常量
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// 角色常量
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// 平台常量
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
)
// 账号类型常量
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeApiKey = "apikey" // API Key类型账号
)
// 卡密类型常量
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
)
// 管理员调整类型常量
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)

View File

@@ -1,45 +0,0 @@
package model
import (
"fmt"
"time"
"gorm.io/gorm"
)
type Proxy struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Protocol string `gorm:"size:20;not null" json:"protocol"` // http/https/socks5
Host string `gorm:"size:255;not null" json:"host"`
Port int `gorm:"not null" json:"port"`
Username string `gorm:"size:100" json:"username"`
Password string `gorm:"size:100" json:"-"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
func (Proxy) TableName() string {
return "proxies"
}
// IsActive 检查是否激活
func (p *Proxy) IsActive() bool {
return p.Status == "active"
}
// URL 返回代理URL
func (p *Proxy) URL() string {
if p.Username != "" && p.Password != "" {
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
}
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
}
// ProxyWithAccountCount extends Proxy with account count information
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}

View File

@@ -1,47 +0,0 @@
package model
import (
"crypto/rand"
"encoding/hex"
"time"
)
type RedeemCode struct {
ID int64 `gorm:"primaryKey" json:"id"`
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
UsedBy *int64 `gorm:"index" json:"used_by"`
UsedAt *time.Time `json:"used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 订阅类型专用字段
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
// 关联
User *User `gorm:"foreignKey:UsedBy" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (RedeemCode) TableName() string {
return "redeem_codes"
}
// IsUsed 检查是否已使用
func (r *RedeemCode) IsUsed() bool {
return r.Status == "used"
}
// CanUse 检查是否可以使用
func (r *RedeemCode) CanUse() bool {
return r.Status == "unused"
}
// GenerateRedeemCode 生成唯一的兑换码
func GenerateRedeemCode() string {
b := make([]byte, 16)
rand.Read(b)
return hex.EncodeToString(b)
}

View File

@@ -1,95 +0,0 @@
package model
import (
"time"
)
// Setting 系统设置模型Key-Value存储
type Setting struct {
ID int64 `gorm:"primaryKey" json:"id"`
Key string `gorm:"uniqueIndex;size:100;not null" json:"key"`
Value string `gorm:"type:text;not null" json:"value"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
}
func (Setting) TableName() string {
return "settings"
}
// 设置Key常量
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
)
// SystemSettings 系统设置结构体用于API响应
type SystemSettings struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"` // 不返回明文密码
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` // 不返回明文密钥
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
// PublicSettings 公开设置(无需登录即可获取)
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
Version string `json:"version"`
}

View File

@@ -1,67 +0,0 @@
package model
import (
"time"
)
// 消费类型常量
const (
BillingTypeBalance int8 = 0 // 钱包余额
BillingTypeSubscription int8 = 1 // 订阅套餐
)
type UsageLog struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
ApiKeyID int64 `gorm:"index;not null" json:"api_key_id"`
AccountID int64 `gorm:"index;not null" json:"account_id"`
RequestID string `gorm:"size:64" json:"request_id"`
Model string `gorm:"size:100;index;not null" json:"model"`
// 订阅关联(可选)
GroupID *int64 `gorm:"index" json:"group_id"`
SubscriptionID *int64 `gorm:"index" json:"subscription_id"`
// Token使用量4类
InputTokens int `gorm:"default:0;not null" json:"input_tokens"`
OutputTokens int `gorm:"default:0;not null" json:"output_tokens"`
CacheCreationTokens int `gorm:"default:0;not null" json:"cache_creation_tokens"`
CacheReadTokens int `gorm:"default:0;not null" json:"cache_read_tokens"`
// 详细的缓存创建分类
CacheCreation5mTokens int `gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
// 费用USD
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
// 元数据
BillingType int8 `gorm:"type:smallint;default:0;not null" json:"billing_type"` // 0=余额 1=订阅
Stream bool `gorm:"default:false;not null" json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"` // 首字时间(流式请求)
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
ApiKey *ApiKey `gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
Subscription *UserSubscription `gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
}
func (UsageLog) TableName() string {
return "usage_logs"
}
// TotalTokens 总token数
func (u *UsageLog) TotalTokens() int {
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
}

View File

@@ -1,74 +0,0 @@
package model
import (
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type User struct {
ID int64 `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
PasswordHash string `gorm:"size:255;not null" json:"-"`
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
Concurrency int `gorm:"default:5;not null" json:"concurrency"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
AllowedGroups pq.Int64Array `gorm:"type:bigint[]" json:"allowed_groups"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
}
func (User) TableName() string {
return "users"
}
// IsAdmin 检查是否管理员
func (u *User) IsAdmin() bool {
return u.Role == "admin"
}
// IsActive 检查是否激活
func (u *User) IsActive() bool {
return u.Status == "active"
}
// CanBindGroup 检查是否可以绑定指定分组
// 对于标准类型分组:
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
return !isExclusive
}
// SetPassword 设置密码(哈希存储)
func (u *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.PasswordHash = string(hash)
return nil
}
// CheckPassword 验证密码
func (u *User) CheckPassword(password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
return err == nil
}

View File

@@ -1,157 +0,0 @@
package model
import (
"time"
)
// 订阅状态常量
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)
// UserSubscription 用户订阅模型
type UserSubscription struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
GroupID int64 `gorm:"index;not null" json:"group_id"`
// 订阅有效期
StartsAt time.Time `gorm:"not null" json:"starts_at"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/expired/suspended
// 滑动窗口起始时间nil = 未激活)
DailyWindowStart *time.Time `json:"daily_window_start"`
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
// 当前窗口已用额度USD基于 total_cost 计算)
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"daily_usage_usd"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"monthly_usage_usd"`
// 管理员分配信息
AssignedBy *int64 `gorm:"index" json:"assigned_by"`
AssignedAt time.Time `gorm:"not null" json:"assigned_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
AssignedByUser *User `gorm:"foreignKey:AssignedBy" json:"assigned_by_user,omitempty"`
}
func (UserSubscription) TableName() string {
return "user_subscriptions"
}
// IsActive 检查订阅是否有效状态为active且未过期
func (s *UserSubscription) IsActive() bool {
return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt)
}
// IsExpired 检查订阅是否已过期
func (s *UserSubscription) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
}
// DaysRemaining 返回订阅剩余天数
func (s *UserSubscription) DaysRemaining() int {
if s.IsExpired() {
return 0
}
return int(time.Until(s.ExpiresAt).Hours() / 24)
}
// IsWindowActivated 检查窗口是否已激活
func (s *UserSubscription) IsWindowActivated() bool {
return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil
}
// NeedsDailyReset 检查日窗口是否需要重置
func (s *UserSubscription) NeedsDailyReset() bool {
if s.DailyWindowStart == nil {
return false
}
return time.Since(*s.DailyWindowStart) >= 24*time.Hour
}
// NeedsWeeklyReset 检查周窗口是否需要重置
func (s *UserSubscription) NeedsWeeklyReset() bool {
if s.WeeklyWindowStart == nil {
return false
}
return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour
}
// NeedsMonthlyReset 检查月窗口是否需要重置
func (s *UserSubscription) NeedsMonthlyReset() bool {
if s.MonthlyWindowStart == nil {
return false
}
return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour
}
// DailyResetTime 返回日窗口重置时间
func (s *UserSubscription) DailyResetTime() *time.Time {
if s.DailyWindowStart == nil {
return nil
}
t := s.DailyWindowStart.Add(24 * time.Hour)
return &t
}
// WeeklyResetTime 返回周窗口重置时间
func (s *UserSubscription) WeeklyResetTime() *time.Time {
if s.WeeklyWindowStart == nil {
return nil
}
t := s.WeeklyWindowStart.Add(7 * 24 * time.Hour)
return &t
}
// MonthlyResetTime 返回月窗口重置时间
func (s *UserSubscription) MonthlyResetTime() *time.Time {
if s.MonthlyWindowStart == nil {
return nil
}
t := s.MonthlyWindowStart.Add(30 * 24 * time.Hour)
return &t
}
// CheckDailyLimit 检查是否超出日限额
func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool {
if !group.HasDailyLimit() {
return true // 无限制
}
return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD
}
// CheckWeeklyLimit 检查是否超出周限额
func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool {
if !group.HasWeeklyLimit() {
return true // 无限制
}
return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD
}
// CheckMonthlyLimit 检查是否超出月限额
func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool {
if !group.HasMonthlyLimit() {
return true // 无限制
}
return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD
}
// CheckAllLimits 检查所有限额
func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) {
daily = s.CheckDailyLimit(group, additionalCost)
weekly = s.CheckWeeklyLimit(group, additionalCost)
monthly = s.CheckMonthlyLimit(group, additionalCost)
return
}

View File

@@ -0,0 +1,74 @@
package claude
// Claude Code 客户端相关常量
// Beta header 常量
const (
BetaOAuth = "oauth-2025-04-20"
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0",
"X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64",
"X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0",
"X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
}
// Model 表示一个 Claude 模型
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels Claude Code 客户端支持的默认模型列表
var DefaultModels = []Model{
{
ID: "claude-opus-4-5-20251101",
Type: "model",
DisplayName: "Claude Opus 4.5",
CreatedAt: "2025-11-01T00:00:00Z",
},
{
ID: "claude-sonnet-4-5-20250929",
Type: "model",
DisplayName: "Claude Sonnet 4.5",
CreatedAt: "2025-09-29T00:00:00Z",
},
{
ID: "claude-haiku-4-5-20251001",
Type: "model",
DisplayName: "Claude Haiku 4.5",
CreatedAt: "2025-10-01T00:00:00Z",
},
}
// DefaultModelIDs 返回默认模型的 ID 列表
func DefaultModelIDs() []string {
ids := make([]string, len(DefaultModels))
for i, m := range DefaultModels {
ids[i] = m.ID
}
return ids
}
// DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929"

View File

@@ -0,0 +1,42 @@
package gemini
// This package provides minimal fallback model metadata for Gemini native endpoints.
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
type Model struct {
Name string `json:"name"`
DisplayName string `json:"displayName,omitempty"`
Description string `json:"description,omitempty"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
}
type ModelsListResponse struct {
Models []Model `json:"models"`
}
func DefaultModels() []Model {
methods := []string{"generateContent", "streamGenerateContent"}
return []Model{
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash-lite", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
}
}
func FallbackModelsList() ModelsListResponse {
return ModelsListResponse{Models: DefaultModels()}
}
func FallbackModel(model string) Model {
methods := []string{"generateContent", "streamGenerateContent"}
if model == "" {
return Model{Name: "models/unknown", SupportedGenerationMethods: methods}
}
if len(model) >= 7 && model[:7] == "models/" {
return Model{Name: model, SupportedGenerationMethods: methods}
}
return Model{Name: "models/" + model, SupportedGenerationMethods: methods}
}

View File

@@ -0,0 +1,38 @@
package geminicli
// LoadCodeAssistRequest matches done-hub's internal Code Assist call.
type LoadCodeAssistRequest struct {
Metadata LoadCodeAssistMetadata `json:"metadata"`
}
type LoadCodeAssistMetadata struct {
IDEType string `json:"ideType"`
Platform string `json:"platform"`
PluginType string `json:"pluginType"`
}
type LoadCodeAssistResponse struct {
CurrentTier string `json:"currentTier,omitempty"`
CloudAICompanionProject string `json:"cloudaicompanionProject,omitempty"`
AllowedTiers []AllowedTier `json:"allowedTiers,omitempty"`
}
type AllowedTier struct {
ID string `json:"id"`
IsDefault bool `json:"isDefault,omitempty"`
}
type OnboardUserRequest struct {
TierID string `json:"tierId"`
Metadata LoadCodeAssistMetadata `json:"metadata"`
}
type OnboardUserResponse struct {
Done bool `json:"done"`
Response *OnboardUserResultData `json:"response,omitempty"`
Name string `json:"name,omitempty"`
}
type OnboardUserResultData struct {
CloudAICompanionProject any `json:"cloudaicompanionProject,omitempty"`
}

View File

@@ -0,0 +1,42 @@
package geminicli
import "time"
const (
AIStudioBaseURL = "https://generativelanguage.googleapis.com"
GeminiCliBaseURL = "https://cloudcode-pa.googleapis.com"
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
TokenURL = "https://oauth2.googleapis.com/token"
// AIStudioOAuthRedirectURI is the default redirect URI used for AI Studio OAuth.
// This matches the "copy/paste callback URL" flow used by OpenAI OAuth in this project.
// Note: You still need to register this redirect URI in your Google OAuth client
// unless you use an OAuth client type that permits localhost redirect URIs.
AIStudioOAuthRedirectURI = "http://localhost:1455/auth/callback"
// DefaultScopes for Code Assist (includes cloud-platform for API access plus userinfo scopes)
// Required by Google's Code Assist API.
DefaultCodeAssistScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
// DefaultScopes for AI Studio (uses generativelanguage API with OAuth)
// Reference: https://ai.google.dev/gemini-api/docs/oauth
// For regular Google accounts, supports API calls to generativelanguage.googleapis.com
// Note: Google Auth platform currently documents the OAuth scope as
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
GeminiCLIRedirectURI = "https://codeassist.google.com/authcode"
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
// They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret = "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"
SessionTTL = 30 * time.Minute
// GeminiCLIUserAgent mimics Gemini CLI to maximize compatibility with internal endpoints.
GeminiCLIUserAgent = "GeminiCLI/0.1.5 (Windows; AMD64)"
)

View File

@@ -0,0 +1,21 @@
package geminicli
// Model represents a selectable Gemini model for UI/testing purposes.
// Keep JSON fields consistent with existing frontend expectations.
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var DefaultModels = []Model{
{ID: "gemini-3-pro", Type: "model", DisplayName: "Gemini 3 Pro", CreatedAt: ""},
{ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.
const DefaultTestModel = "gemini-2.5-pro"

View File

@@ -0,0 +1,243 @@
package geminicli
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
type OAuthConfig struct {
ClientID string
ClientSecret string
Scopes string
}
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
CreatedAt time.Time `json:"created_at"`
}
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
go store.cleanup()
return store
}
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
func (s *SessionStore) Stop() {
select {
case <-s.stopCh:
return
default:
close(s.stopCh)
}
}
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier returns an RFC 7636 compatible code verifier (43+ chars).
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
func base64URLEncode(data []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
}
// EffectiveOAuthConfig returns the effective OAuth configuration.
// oauthType: "code_assist" or "ai_studio" (defaults to "code_assist" if empty).
//
// If ClientID/ClientSecret is not provided, this falls back to the built-in Gemini CLI OAuth client.
//
// Note: The built-in Gemini CLI OAuth client is restricted and may reject some scopes (e.g.
// https://www.googleapis.com/auth/generative-language), which will surface as
// "restricted_client" / "Unregistered scope(s)" errors during browser authorization.
func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error) {
effective := OAuthConfig{
ClientID: strings.TrimSpace(cfg.ClientID),
ClientSecret: strings.TrimSpace(cfg.ClientSecret),
Scopes: strings.TrimSpace(cfg.Scopes),
}
// Normalize scopes: allow comma-separated input but send space-delimited scopes to Google.
if effective.Scopes != "" {
effective.Scopes = strings.Join(strings.Fields(strings.ReplaceAll(effective.Scopes, ",", " ")), " ")
}
// Fall back to built-in Gemini CLI OAuth client when not configured.
if effective.ClientID == "" && effective.ClientSecret == "" {
effective.ClientID = GeminiCLIOAuthClientID
effective.ClientSecret = GeminiCLIOAuthClientSecret
} else if effective.ClientID == "" || effective.ClientSecret == "" {
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
}
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID &&
effective.ClientSecret == GeminiCLIOAuthClientSecret
if effective.Scopes == "" {
// Use different default scopes based on OAuth type
if oauthType == "ai_studio" {
// Built-in client can't request some AI Studio scopes (notably generative-language).
if isBuiltinClient {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = DefaultAIStudioScopes
}
} else {
// Default to Code Assist scopes
effective.Scopes = DefaultCodeAssistScopes
}
} else if oauthType == "ai_studio" && isBuiltinClient {
// If user overrides scopes while still using the built-in client, strip restricted scopes.
parts := strings.Fields(effective.Scopes)
filtered := make([]string, 0, len(parts))
for _, s := range parts {
if strings.Contains(s, "generative-language") {
continue
}
filtered = append(filtered, s)
}
if len(filtered) == 0 {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = strings.Join(filtered, " ")
}
}
// Backward compatibility: normalize older AI Studio scope to the currently documented one.
if oauthType == "ai_studio" && effective.Scopes != "" {
parts := strings.Fields(effective.Scopes)
for i := range parts {
if parts[i] == "https://www.googleapis.com/auth/generative-language" {
parts[i] = "https://www.googleapis.com/auth/generative-language.retriever"
}
}
effective.Scopes = strings.Join(parts, " ")
}
return effective, nil
}
func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI, projectID, oauthType string) (string, error) {
effectiveCfg, err := EffectiveOAuthConfig(cfg, oauthType)
if err != nil {
return "", err
}
redirectURI = strings.TrimSpace(redirectURI)
if redirectURI == "" {
return "", fmt.Errorf("redirect_uri is required")
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", effectiveCfg.ClientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", effectiveCfg.Scopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
params.Set("access_type", "offline")
params.Set("prompt", "consent")
params.Set("include_granted_scopes", "true")
if strings.TrimSpace(projectID) != "" {
params.Set("project_id", strings.TrimSpace(projectID))
}
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()), nil
}

View File

@@ -0,0 +1,46 @@
package geminicli
import "strings"
const maxLogBodyLen = 2048
func SanitizeBodyForLogs(body string) string {
body = truncateBase64InMessage(body)
if len(body) > maxLogBodyLen {
body = body[:maxLogBodyLen] + "...[truncated]"
}
return body
}
func truncateBase64InMessage(message string) string {
const maxBase64Length = 50
result := message
offset := 0
for {
idx := strings.Index(result[offset:], ";base64,")
if idx == -1 {
break
}
actualIdx := offset + idx
start := actualIdx + len(";base64,")
end := start
for end < len(result) && isBase64Char(result[end]) {
end++
}
if end-start > maxBase64Length {
result = result[:start+maxBase64Length] + "...[truncated]" + result[end:]
offset = start + maxBase64Length + len("...[truncated]")
continue
}
offset = end
}
return result
}
func isBase64Char(c byte) bool {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '+' || c == '/' || c == '='
}

View File

@@ -0,0 +1,9 @@
package geminicli
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
Scope string `json:"scope,omitempty"`
}

View File

@@ -0,0 +1,24 @@
package googleapi
import "net/http"
// HTTPStatusToGoogleStatus maps HTTP status codes to Google-style error status strings.
func HTTPStatusToGoogleStatus(status int) string {
switch status {
case http.StatusBadRequest:
return "INVALID_ARGUMENT"
case http.StatusUnauthorized:
return "UNAUTHENTICATED"
case http.StatusForbidden:
return "PERMISSION_DENIED"
case http.StatusNotFound:
return "NOT_FOUND"
case http.StatusTooManyRequests:
return "RESOURCE_EXHAUSTED"
default:
if status >= 500 {
return "INTERNAL"
}
return "UNKNOWN"
}
}

View File

@@ -43,18 +43,25 @@ type OAuthSession struct {
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
close(s.stopCh)
}
// Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
@@ -87,14 +94,20 @@ func (s *SessionStore) Delete(sessionID string) {
// cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
s.mu.Unlock()
}
}

View File

@@ -0,0 +1,42 @@
package openai
import _ "embed"
// Model represents an OpenAI model
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
}
// DefaultModels OpenAI models list
var DefaultModels = []Model{
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
{ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
{ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
{ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
{ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
}
// DefaultModelIDs returns the default model ID list
func DefaultModelIDs() []string {
ids := make([]string, len(DefaultModels))
for i, m := range DefaultModels {
ids[i] = m.ID
}
return ids
}
// DefaultTestModel default model for testing OpenAI accounts
const DefaultTestModel = "gpt-5.1-codex"
// DefaultInstructions default instructions for non-Codex CLI requests
// Content loaded from instructions.txt at compile time
//
//go:embed instructions.txt
var DefaultInstructions string

View File

@@ -0,0 +1,118 @@
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
## General
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
## Editing constraints
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like \"Assigns the value to the variable\", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
- You may be in a dirty git worktree.
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
* If the changes are in unrelated files, just ignore them and don't revert them.
- Do not amend a commit unless explicitly requested to do so.
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
## Plan tool
When using the planning tool:
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
- Do not make single-step plans.
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
## Codex CLI harness, sandboxing, and approvals
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
- **read-only**: The sandbox only permits reading files.
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
- **restricted**: Requires approval
- **enabled**: No approval needed
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe \"read\" commands.
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
- (for all of these, you should weigh alternative paths that do not require approval)
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to \"never\", in which case never ask for approvals.
When requesting approval to execute a command that will require escalated privileges:
- Provide the `sandbox_permissions` parameter with the value `\"require_escalated\"`
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
## Special user requests
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
- If the user asks for a \"review\", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
## Frontend tasks
When doing frontend design tasks, avoid collapsing into \"AI slop\" or safe, average-looking layouts.
Aim for interfaces that feel intentional, bold, and a bit surprising.
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
- Ensure the page loads properly on both desktop and mobile
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
## Presenting your work and final message
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
- Default: be very concise; friendly coding teammate tone.
- Ask only when needed; suggest ideas; mirror the user's style.
- For substantial work, summarize clearly; follow finalanswer formatting.
- Skip heavy formatting for simple confirmations.
- Don't dump large files you've written; reference paths only.
- No \"save/copy this file\" - User is on the same machine.
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
- For code changes:
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with \"summary\", just jump right in.
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
### Final answer structure and style guidelines
- Plain text; CLI handles styling. Use structure only when it helps scanability.
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
- Bullets: use - ; merge related points; keep to one line when possible; 46 per list ordered by importance; keep phrasing consistent.
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
- Tone: collaborative, concise, factual; present tense, active voice; selfcontained; no \"above/below\"; parallel wording.
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
- File References: When referencing files in your response follow the below rules:
* Use inline code to make file paths clickable.
* Each reference should have a stand alone path. Even if it's the same file.
* Accepted: absolute, workspacerelative, a/ or b/ diff prefixes, or bare filename/suffix.
* Optionally include line/column (1based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
* Do not use URIs like file://, vscode://, or https://.
* Do not provide range of lines
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\\repo\\project\\main.rs:12:5

View File

@@ -0,0 +1,366 @@
package openai
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
// OpenAI OAuth Constants (from CRS project - Codex CLI client)
const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
TokenURL = "https://auth.openai.com/oauth/token"
// Default redirect URI (can be customized)
DefaultRedirectURI = "http://localhost:1455/auth/callback"
// Scopes
DefaultScopes = "openid profile email offline_access"
// RefreshScopes - scope for token refresh (without offline_access, aligned with CRS project)
RefreshScopes = "openid profile email"
// Session TTL
SessionTTL = 30 * time.Minute
)
// OAuthSession stores OAuth flow state for OpenAI
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore manages OAuth sessions in memory
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
// Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
// Get retrieves a session
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
// Delete removes a session
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
close(s.stopCh)
}
// cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
// GenerateRandomBytes generates cryptographically secure random bytes
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
// GenerateState generates a random state string for OAuth
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateSessionID generates a unique session ID
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (64 bytes -> hex for OpenAI)
// OpenAI uses hex encoding instead of base64url
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(64)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
// Uses base64url encoding as per RFC 7636
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", DefaultScopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
// OpenAI specific parameters
params.Set("id_token_add_organizations", "true")
params.Set("codex_cli_simplified_flow", "true")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
}
// TokenResponse represents the token response from OpenAI OAuth
type TokenResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
}
// IDTokenClaims represents the claims from OpenAI ID Token
type IDTokenClaims struct {
// Standard claims
Sub string `json:"sub"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Iss string `json:"iss"`
Aud []string `json:"aud"` // OpenAI returns aud as an array
Exp int64 `json:"exp"`
Iat int64 `json:"iat"`
// OpenAI specific claims (nested under https://api.openai.com/auth)
OpenAIAuth *OpenAIAuthClaims `json:"https://api.openai.com/auth,omitempty"`
}
// OpenAIAuthClaims represents the OpenAI specific auth claims
type OpenAIAuthClaims struct {
ChatGPTAccountID string `json:"chatgpt_account_id"`
ChatGPTUserID string `json:"chatgpt_user_id"`
UserID string `json:"user_id"`
Organizations []OrganizationClaim `json:"organizations"`
}
// OrganizationClaim represents an organization in the ID Token
type OrganizationClaim struct {
ID string `json:"id"`
Role string `json:"role"`
Title string `json:"title"`
IsDefault bool `json:"is_default"`
}
// BuildTokenRequest creates a token exchange request for OpenAI
func BuildTokenRequest(code, codeVerifier, redirectURI string) *TokenRequest {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: redirectURI,
CodeVerifier: codeVerifier,
}
}
// BuildRefreshTokenRequest creates a refresh token request for OpenAI
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
Scope: RefreshScopes,
}
}
// ToFormData converts TokenRequest to URL-encoded form data
func (r *TokenRequest) ToFormData() string {
params := url.Values{}
params.Set("grant_type", r.GrantType)
params.Set("client_id", r.ClientID)
params.Set("code", r.Code)
params.Set("redirect_uri", r.RedirectURI)
params.Set("code_verifier", r.CodeVerifier)
return params.Encode()
}
// ToFormData converts RefreshTokenRequest to URL-encoded form data
func (r *RefreshTokenRequest) ToFormData() string {
params := url.Values{}
params.Set("grant_type", r.GrantType)
params.Set("client_id", r.ClientID)
params.Set("refresh_token", r.RefreshToken)
params.Set("scope", r.Scope)
return params.Encode()
}
// ParseIDToken parses the ID Token JWT and extracts claims
// Note: This does NOT verify the signature - it only decodes the payload
// For production, you should verify the token signature using OpenAI's public keys
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
// Decode payload (second part)
payload := parts[1]
// Add padding if necessary
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
decoded, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
// Try standard encoding
decoded, err = base64.StdEncoding.DecodeString(payload)
if err != nil {
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
}
}
var claims IDTokenClaims
if err := json.Unmarshal(decoded, &claims); err != nil {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
}
return &claims, nil
}
// ExtractUserInfo extracts user information from ID Token claims
type UserInfo struct {
Email string
ChatGPTAccountID string
ChatGPTUserID string
UserID string
OrganizationID string
Organizations []OrganizationClaim
}
// GetUserInfo extracts user info from ID Token claims
func (c *IDTokenClaims) GetUserInfo() *UserInfo {
info := &UserInfo{
Email: c.Email,
}
if c.OpenAIAuth != nil {
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
info.UserID = c.OpenAIAuth.UserID
info.Organizations = c.OpenAIAuth.Organizations
// Get default organization ID
for _, org := range c.OpenAIAuth.Organizations {
if org.IsDefault {
info.OrganizationID = org.ID
break
}
}
// If no default, use first org
if info.OrganizationID == "" && len(c.OpenAIAuth.Organizations) > 0 {
info.OrganizationID = c.OpenAIAuth.Organizations[0].ID
}
}
return info
}

View File

@@ -0,0 +1,18 @@
package openai
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
var CodexCLIUserAgentPrefixes = []string{
"codex_vscode/",
"codex_cli_rs/",
}
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
func IsCodexCLIRequest(userAgent string) bool {
for _, prefix := range CodexCLIUserAgentPrefixes {
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
return true
}
}
return false
}

View File

@@ -1,17 +1,4 @@
package repository
// Repositories 所有仓库的集合
type Repositories struct {
User *UserRepository
ApiKey *ApiKeyRepository
Group *GroupRepository
Account *AccountRepository
Proxy *ProxyRepository
RedeemCode *RedeemCodeRepository
UsageLog *UsageLogRepository
Setting *SettingRepository
UserSubscription *UserSubscriptionRepository
}
package pagination
// PaginationParams 分页参数
type PaginationParams struct {

View File

@@ -4,27 +4,30 @@ import (
"math"
"net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin"
)
// Response 标准API响应格式
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Data any `json:"data,omitempty"`
}
// PaginatedData 分页数据格式(匹配前端期望)
type PaginatedData struct {
Items interface{} `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Pages int `json:"pages"`
Items any `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Pages int `json:"pages"`
}
// Success 返回成功响应
func Success(c *gin.Context, data interface{}) {
func Success(c *gin.Context, data any) {
c.JSON(http.StatusOK, Response{
Code: 0,
Message: "success",
@@ -33,7 +36,7 @@ func Success(c *gin.Context, data interface{}) {
}
// Created 返回创建成功响应
func Created(c *gin.Context, data interface{}) {
func Created(c *gin.Context, data any) {
c.JSON(http.StatusCreated, Response{
Code: 0,
Message: "success",
@@ -44,11 +47,36 @@ func Created(c *gin.Context, data interface{}) {
// Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Code: statusCode,
Message: message,
Reason: "",
Metadata: nil,
})
}
// ErrorWithDetails returns an error response compatible with the existing envelope while
// optionally providing structured error fields (reason/metadata).
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Reason: reason,
Metadata: metadata,
})
}
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
// It returns true if an error was written.
func ErrorFrom(c *gin.Context, err error) bool {
if err == nil {
return false
}
statusCode, status := infraerrors.ToHTTP(err)
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}
// BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message)
@@ -75,7 +103,7 @@ func InternalError(c *gin.Context, message string) {
}
// Paginated 返回分页数据
func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize int) {
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
pages := int(math.Ceil(float64(total) / float64(pageSize)))
if pages < 1 {
pages = 1
@@ -90,7 +118,7 @@ func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize in
})
}
// PaginationResult 分页结果(与repository.PaginationResult兼容
// PaginationResult 分页结果(与pagination.PaginationResult兼容
type PaginationResult struct {
Total int64
Page int
@@ -99,7 +127,7 @@ type PaginationResult struct {
}
// PaginatedWithResult 使用PaginationResult返回分页数据
func PaginatedWithResult(c *gin.Context, items interface{}, pagination *PaginationResult) {
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
if pagination == nil {
Success(c, PaginatedData{
Items: items,

View File

@@ -0,0 +1,171 @@
//go:build unit
package response
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestErrorWithDetails(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
message string
reason string
metadata map[string]string
want Response
}{
{
name: "plain_error",
statusCode: http.StatusBadRequest,
message: "invalid request",
want: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
},
},
{
name: "structured_error",
statusCode: http.StatusForbidden,
message: "no access",
reason: "FORBIDDEN",
metadata: map[string]string{"k": "v"},
want: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"k": "v"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
require.Equal(t, tt.statusCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.want, got)
})
}
}
func TestErrorFrom(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
err error
wantWritten bool
wantHTTPCode int
wantBody Response
}{
{
name: "nil_error",
err: nil,
wantWritten: false,
},
{
name: "application_error",
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
wantWritten: true,
wantHTTPCode: http.StatusForbidden,
wantBody: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"scope": "admin"},
},
},
{
name: "bad_request_error",
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"),
wantWritten: true,
wantHTTPCode: http.StatusBadRequest,
wantBody: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
Reason: "INVALID_REQUEST",
},
},
{
name: "unauthorized_error",
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"),
wantWritten: true,
wantHTTPCode: http.StatusUnauthorized,
wantBody: Response{
Code: http.StatusUnauthorized,
Message: "unauthorized",
Reason: "UNAUTHORIZED",
},
},
{
name: "not_found_error",
err: infraerrors.NotFound("NOT_FOUND", "not found"),
wantWritten: true,
wantHTTPCode: http.StatusNotFound,
wantBody: Response{
Code: http.StatusNotFound,
Message: "not found",
Reason: "NOT_FOUND",
},
},
{
name: "conflict_error",
err: infraerrors.Conflict("CONFLICT", "conflict"),
wantWritten: true,
wantHTTPCode: http.StatusConflict,
wantBody: Response{
Code: http.StatusConflict,
Message: "conflict",
Reason: "CONFLICT",
},
},
{
name: "unknown_error_defaults_to_500",
err: errors.New("boom"),
wantWritten: true,
wantHTTPCode: http.StatusInternalServerError,
wantBody: Response{
Code: http.StatusInternalServerError,
Message: infraerrors.UnknownMessage,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
written := ErrorFrom(c, tt.err)
require.Equal(t, tt.wantWritten, written)
if !tt.wantWritten {
require.Equal(t, 200, w.Code)
require.Empty(t, w.Body.String())
return
}
require.Equal(t, tt.wantHTTPCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.wantBody, got)
})
}
}

View File

@@ -37,11 +37,15 @@ func TestInitInvalidTimezone(t *testing.T) {
func TestTimeNowAffected(t *testing.T) {
// Reset to UTC first
Init("UTC")
if err := Init("UTC"); err != nil {
t.Fatalf("Init failed with UTC: %v", err)
}
utcNow := time.Now()
// Switch to Shanghai (UTC+8)
Init("Asia/Shanghai")
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
shanghaiNow := time.Now()
// The times should be the same instant, but different timezone representation
@@ -58,7 +62,9 @@ func TestTimeNowAffected(t *testing.T) {
}
func TestToday(t *testing.T) {
Init("Asia/Shanghai")
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
today := Today()
now := Now()
@@ -75,7 +81,9 @@ func TestToday(t *testing.T) {
}
func TestStartOfDay(t *testing.T) {
Init("Asia/Shanghai")
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
// Create a time at 15:30:45
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
@@ -91,7 +99,9 @@ func TestTruncateVsStartOfDay(t *testing.T) {
// This test demonstrates why Truncate(24*time.Hour) can be problematic
// and why StartOfDay is more reliable for timezone-aware code
Init("Asia/Shanghai")
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
now := Now()

View File

@@ -0,0 +1,8 @@
package usagestats
// AccountStats 账号使用统计
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}

View File

@@ -0,0 +1,209 @@
package usagestats
import "time"
// DashboardStats 仪表盘统计
type DashboardStats struct {
// 用户统计
TotalUsers int64 `json:"total_users"`
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
// 账户统计
TotalAccounts int64 `json:"total_accounts"`
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 系统运行统计
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
}
// TrendDataPoint represents a single point in trend data
type TrendDataPoint struct {
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheTokens int64 `json:"cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ModelStat represents usage statistics for a single model
type ModelStat struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint struct {
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct {
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct {
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 性能统计
AverageDurationMs float64 `json:"average_duration_ms"`
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
}
// UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct {
UserID int64
ApiKeyID int64
StartTime *time.Time
EndTime *time.Time
}
// UsageStats represents usage statistics
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats struct {
ApiKeyID int64 `json:"api_key_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
ActualCost float64 `json:"actual_cost"`
}
// AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary struct {
Days int `json:"days"`
ActualDaysUsed int `json:"actual_days_used"`
TotalCost float64 `json:"total_cost"`
TotalStandardCost float64 `json:"total_standard_cost"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
AvgDailyCost float64 `json:"avg_daily_cost"`
AvgDailyRequests float64 `json:"avg_daily_requests"`
AvgDailyTokens float64 `json:"avg_daily_tokens"`
AvgDurationMs float64 `json:"avg_duration_ms"`
Today *struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
} `json:"today"`
HighestCostDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
} `json:"highest_cost_day"`
HighestRequestDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
} `json:"highest_request_day"`
}
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse struct {
History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
}

View File

@@ -2,63 +2,85 @@ package repository
import (
"context"
"sub2api/internal/model"
"errors"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type AccountRepository struct {
type accountRepository struct {
db *gorm.DB
}
func NewAccountRepository(db *gorm.DB) *AccountRepository {
return &AccountRepository{db: db}
func NewAccountRepository(db *gorm.DB) service.AccountRepository {
return &accountRepository{db: db}
}
func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Create(account).Error
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
m := accountModelFromService(account)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyAccountModelToService(account, m)
}
return err
}
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups").First(&account, id).Error
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) {
var m accountModel
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
return accountModelToService(&m), nil
}
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
if crsAccountID == "" {
return nil, nil
}
var m accountModel
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&m).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
// 填充 GroupIDs 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
for _, ag := range account.AccountGroups {
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
return accountModelToService(&m), nil
}
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
m := accountModelFromService(account)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyAccountModelToService(account, m)
}
return &account, nil
return err
}
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Save(account).Error
}
func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
// 先删除账号与分组的绑定关系
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&accountGroupModel{}).Error; err != nil {
return err
}
// 再删除账号
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
return r.db.WithContext(ctx).Delete(&accountModel{}, id).Error
}
func (r *AccountRepository) List(ctx context.Context, params PaginationParams) ([]model.Account, *PaginationResult, error) {
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "")
}
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *AccountRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, accountType, status, search string) ([]model.Account, *PaginationResult, error) {
var accounts []model.Account
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
var accounts []accountModel
var total int64
db := r.db.WithContext(ctx).Model(&model.Account{})
db := r.db.WithContext(ctx).Model(&accountModel{})
// Apply filters
if platform != "" {
db = db.Where("platform = ?", platform)
}
@@ -77,67 +99,88 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params Paginati
return nil, nil, err
}
if err := db.Preload("Proxy").Preload("AccountGroups").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
if err := db.Preload("Proxy").Preload("AccountGroups.Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
return nil, nil, err
}
// 填充每个 Account 的 GroupIDs 虚拟字段
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
for _, ag := range accounts[i].AccountGroups {
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
}
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return accounts, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outAccounts, paginationResultFromTotal(total, params), nil
}
func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, model.StatusActive).
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, service.StatusActive).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive).
Where("status = ?", service.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, service.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Update("last_used_at", now).Error
}
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
"status": model.StatusError,
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"status": service.StatusError,
"error_message": errorMsg,
}).Error
}
func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &model.AccountGroup{
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
@@ -145,111 +188,159 @@ func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID i
return r.db.WithContext(ctx).Create(ag).Error
}
func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&model.AccountGroup{}).Error
Delete(&accountGroupModel{}).Error
}
func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
var groups []model.Group
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id").
Where("account_groups.account_id = ?", accountID).
Find(&groups).Error
return groups, err
if err != nil {
return nil, err
}
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
return outGroups, nil
}
func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, model.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
}
func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
// 删除现有绑定
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&accountGroupModel{}).Error; err != nil {
return err
}
// 添加新绑定
if len(groupIDs) > 0 {
accountGroups := make([]model.AccountGroup, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, model.AccountGroup{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1, // 使用索引作为优先级
})
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
if len(groupIDs) == 0 {
return nil
}
return nil
accountGroups := make([]accountGroupModel, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1,
})
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
}
// ListSchedulable 获取所有可调度的账号
func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// ListSchedulableByGroupID 按组获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// SetRateLimited 标记账号为限流状态(429)
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
"rate_limited_at": now,
err := r.db.WithContext(ctx).
Where("platform = ?", platform).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.platform = ?", platform).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"rate_limited_at": now,
"rate_limit_reset_at": resetAt,
}).Error
}
// SetOverloaded 标记账号为过载状态(529)
func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("overload_until", until).Error
}
// ClearRateLimit 清除账号的限流状态
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"rate_limited_at": nil,
"rate_limit_reset_at": nil,
"overload_until": nil,
}).Error
}
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]interface{}{
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]any{
"session_window_status": status,
}
if start != nil {
@@ -258,11 +349,241 @@ func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
if end != nil {
updates["session_window_end"] = end
}
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Updates(updates).Error
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Updates(updates).Error
}
// SetSchedulable 设置账号的调度开关
func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("schedulable", schedulable).Error
}
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 {
return nil
}
var account accountModel
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
return err
}
if account.Extra == nil {
account.Extra = datatypes.JSONMap{}
}
for k, v := range updates {
account.Extra[k] = v
}
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("extra", account.Extra).Error
}
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
updateMap := map[string]any{}
if updates.Name != nil {
updateMap["name"] = *updates.Name
}
if updates.ProxyID != nil {
updateMap["proxy_id"] = updates.ProxyID
}
if updates.Concurrency != nil {
updateMap["concurrency"] = *updates.Concurrency
}
if updates.Priority != nil {
updateMap["priority"] = *updates.Priority
}
if updates.Status != nil {
updateMap["status"] = *updates.Status
}
if len(updates.Credentials) > 0 {
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", datatypes.JSONMap(updates.Credentials))
}
if len(updates.Extra) > 0 {
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", datatypes.JSONMap(updates.Extra))
}
if len(updateMap) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).
Model(&accountModel{}).
Where("id IN ?", ids).
Clauses(clause.Returning{}).
Updates(updateMap)
return result.RowsAffected, result.Error
}
type accountModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Platform string `gorm:"size:50;not null"`
Type string `gorm:"size:20;not null"`
Credentials datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
Extra datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
ProxyID *int64 `gorm:"index"`
Concurrency int `gorm:"default:3;not null"`
Priority int `gorm:"default:50;not null"`
Status string `gorm:"size:20;default:active;not null"`
ErrorMessage string `gorm:"type:text"`
LastUsedAt *time.Time `gorm:"index"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
Schedulable bool `gorm:"default:true;not null"`
RateLimitedAt *time.Time `gorm:"index"`
RateLimitResetAt *time.Time `gorm:"index"`
OverloadUntil *time.Time `gorm:"index"`
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string `gorm:"size:20"`
Proxy *proxyModel `gorm:"foreignKey:ProxyID"`
AccountGroups []accountGroupModel `gorm:"foreignKey:AccountID"`
}
func (accountModel) TableName() string { return "accounts" }
type accountGroupModel struct {
AccountID int64 `gorm:"primaryKey"`
GroupID int64 `gorm:"primaryKey"`
Priority int `gorm:"default:50;not null"`
CreatedAt time.Time `gorm:"not null"`
Account *accountModel `gorm:"foreignKey:AccountID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (accountGroupModel) TableName() string { return "account_groups" }
func accountGroupModelToService(m *accountGroupModel) *service.AccountGroup {
if m == nil {
return nil
}
return &service.AccountGroup{
AccountID: m.AccountID,
GroupID: m.GroupID,
Priority: m.Priority,
CreatedAt: m.CreatedAt,
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
}
}
func accountModelToService(m *accountModel) *service.Account {
if m == nil {
return nil
}
var credentials map[string]any
if m.Credentials != nil {
credentials = map[string]any(m.Credentials)
}
var extra map[string]any
if m.Extra != nil {
extra = map[string]any(m.Extra)
}
account := &service.Account{
ID: m.ID,
Name: m.Name,
Platform: m.Platform,
Type: m.Type,
Credentials: credentials,
Extra: extra,
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
Status: m.Status,
ErrorMessage: m.ErrorMessage,
LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: m.SessionWindowStatus,
Proxy: proxyModelToService(m.Proxy),
}
if len(m.AccountGroups) > 0 {
account.AccountGroups = make([]service.AccountGroup, 0, len(m.AccountGroups))
account.GroupIDs = make([]int64, 0, len(m.AccountGroups))
account.Groups = make([]*service.Group, 0, len(m.AccountGroups))
for i := range m.AccountGroups {
ag := accountGroupModelToService(&m.AccountGroups[i])
if ag == nil {
continue
}
account.AccountGroups = append(account.AccountGroups, *ag)
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
}
return account
}
func accountModelFromService(a *service.Account) *accountModel {
if a == nil {
return nil
}
var credentials datatypes.JSONMap
if a.Credentials != nil {
credentials = datatypes.JSONMap(a.Credentials)
}
var extra datatypes.JSONMap
if a.Extra != nil {
extra = datatypes.JSONMap(a.Extra)
}
return &accountModel{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: credentials,
Extra: extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
}
}
func applyAccountModelToService(account *service.Account, m *accountModel) {
if account == nil || m == nil {
return
}
account.ID = m.ID
account.CreatedAt = m.CreatedAt
account.UpdatedAt = m.UpdatedAt
}

View File

@@ -0,0 +1,585 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
)
type AccountRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *accountRepository
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewAccountRepository(s.db).(*accountRepository)
}
func TestAccountRepoSuite(t *testing.T) {
suite.Run(t, new(AccountRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *AccountRepoSuite) TestCreate() {
account := &service.Account{
Name: "test-create",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{},
Extra: map[string]any{},
Concurrency: 3,
Priority: 50,
Schedulable: true,
}
err := s.repo.Create(s.ctx, account)
s.Require().NoError(err, "Create")
s.Require().NotZero(account.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *AccountRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *AccountRepoSuite) TestUpdate() {
account := accountModelToService(mustCreateAccount(s.T(), s.db, &accountModel{Name: "original"}))
account.Name = "updated"
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, account.ID)
s.Require().Error(err, "expected error after delete")
}
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete should cascade remove bindings")
var count int64
s.db.Model(&accountGroupModel{}).Where("account_id = ?", account.ID).Count(&count)
s.Require().Zero(count, "expected bindings to be removed")
}
// --- List / ListWithFilters ---
func (s *AccountRepoSuite) TestList() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc2"})
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(accounts, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *AccountRepoSuite) TestListWithFilters() {
tests := []struct {
name string
setup func(db *gorm.DB)
platform string
accType string
status string
search string
wantCount int
validate func(accounts []service.Account)
}{
{
name: "filter_by_platform",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic})
mustCreateAccount(s.T(), db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI})
},
platform: service.PlatformOpenAI,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.PlatformOpenAI, accounts[0].Platform)
},
},
{
name: "filter_by_type",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &accountModel{Name: "t2", Type: service.AccountTypeApiKey})
},
accType: service.AccountTypeApiKey,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
},
},
{
name: "filter_by_status",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "s1", Status: service.StatusActive})
mustCreateAccount(s.T(), db, &accountModel{Name: "s2", Status: service.StatusDisabled})
},
status: service.StatusDisabled,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.StatusDisabled, accounts[0].Status)
},
},
{
name: "filter_by_search",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &accountModel{Name: "beta-account"})
},
search: "alpha",
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Contains(accounts[0].Name, "alpha")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
db := testTx(s.T())
repo := NewAccountRepository(db).(*accountRepository)
ctx := context.Background()
tt.setup(db)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil {
tt.validate(accounts)
}
})
}
}
// --- ListByGroup / ListActive / ListByPlatform ---
func (s *AccountRepoSuite) TestListByGroup() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Status: service.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Status: service.StatusActive})
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
s.Require().NoError(err, "ListByGroup")
s.Require().Len(accounts, 2)
// Should be ordered by priority
s.Require().Equal(acc2.ID, accounts[0].ID, "expected acc2 first (priority=1)")
}
func (s *AccountRepoSuite) TestListActive() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "active1", Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "inactive1", Status: service.StatusDisabled})
accounts, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(accounts, 1)
s.Require().Equal("active1", accounts[0].Name)
}
func (s *AccountRepoSuite) TestListByPlatform() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListByPlatform")
s.Require().Len(accounts, 1)
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
}
// --- Preload and VirtualFields ---
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
account := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc1",
ProxyID: &proxy.ID,
})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().NotNil(got.Proxy, "expected Proxy preload")
s.Require().Equal(proxy.ID, got.Proxy.ID)
s.Require().Len(got.GroupIDs, 1, "expected GroupIDs to be populated")
s.Require().Equal(group.ID, got.GroupIDs[0])
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
s.Require().Equal(group.ID, got.Groups[0].ID)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(accounts, 1)
s.Require().NotNil(accounts[0].Proxy, "expected Proxy preload in list")
s.Require().Equal(proxy.ID, accounts[0].Proxy.ID)
s.Require().Len(accounts[0].GroupIDs, 1, "expected GroupIDs in list")
s.Require().Equal(group.ID, accounts[0].GroupIDs[0])
}
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc"})
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups")
s.Require().Len(groups, 1, "expected 1 group")
s.Require().Equal(g1.ID, groups[0].ID)
s.Require().NoError(s.repo.RemoveFromGroup(s.ctx, account.ID, g1.ID), "RemoveFromGroup")
groups, err = s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups after remove")
s.Require().Empty(groups, "expected 0 groups after remove")
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{g1.ID, g2.ID}), "BindGroups")
groups, err = s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups after bind")
s.Require().Len(groups, 2, "expected 2 groups after bind")
}
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Empty(groups, "expected 0 groups after binding empty list")
}
// --- Schedulable ---
func (s *AccountRepoSuite) TestListSchedulable() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
sched, err := s.repo.ListSchedulable(s.ctx)
s.Require().NoError(err, "ListSchedulable")
ids := idsOfAccounts(sched)
s.Require().Contains(ids, okAcc.ID)
s.Require().NotContains(ids, overloaded.ID)
}
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.db, &accountModel{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
sched, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ListSchedulableByGroupID")
s.Require().Len(sched, 1, "expected only ok account schedulable")
s.Require().Equal(okAcc.ID, sched[0].ID)
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, rateLimited.ID), "ClearRateLimit")
sched2, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ListSchedulableByGroupID after ClearRateLimit")
s.Require().Len(sched2, 2, "expected 2 schedulable accounts after ClearRateLimit")
}
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
}
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(a1.ID, accounts[0].ID)
}
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-sched", Schedulable: true})
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().False(got.Schedulable)
}
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func (s *AccountRepoSuite) TestSetOverloaded() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-over"})
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.OverloadUntil)
s.Require().WithinDuration(until, *got.OverloadUntil, time.Second)
}
func (s *AccountRepoSuite) TestSetRateLimited() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-rl"})
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.RateLimitedAt)
s.Require().NotNil(got.RateLimitResetAt)
s.Require().WithinDuration(resetAt, *got.RateLimitResetAt, time.Second)
}
func (s *AccountRepoSuite) TestClearRateLimit() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-clear"})
until := time.Now().Add(1 * time.Hour)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, account.ID))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Nil(got.RateLimitedAt)
s.Require().Nil(got.RateLimitResetAt)
s.Require().Nil(got.OverloadUntil)
}
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-used"})
s.Require().Nil(account.LastUsedAt)
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.LastUsedAt)
}
// --- SetError ---
func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-err", Status: service.StatusActive})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal(service.StatusError, got.Status)
s.Require().Equal("something went wrong", got.ErrorMessage)
}
// --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-win"})
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.UpdateSessionWindow(s.ctx, account.ID, &start, &end, "active"))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.SessionWindowStart)
s.Require().NotNil(got.SessionWindowEnd)
s.Require().Equal("active", got.SessionWindowStatus)
}
// --- UpdateExtra ---
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
account := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc-extra",
Extra: datatypes.JSONMap{"a": "1"},
})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("1", got.Extra["a"])
s.Require().Equal("2", got.Extra["b"])
}
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-extra-empty"})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
}
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-nil-extra", Extra: nil})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal("val", got.Extra["key"])
}
// --- GetByCRSAccountID ---
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
crsID := "crs-12345"
mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc-crs",
Extra: datatypes.JSONMap{"crs_account_id": crsID},
})
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
s.Require().NoError(err)
s.Require().NotNil(got)
s.Require().Equal("acc-crs", got.Name)
}
func (s *AccountRepoSuite) TestGetByCRSAccountID_NotFound() {
got, err := s.repo.GetByCRSAccountID(s.ctx, "non-existent")
s.Require().NoError(err)
s.Require().Nil(got)
}
func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
got, err := s.repo.GetByCRSAccountID(s.ctx, "")
s.Require().NoError(err)
s.Require().Nil(got)
}
// --- BulkUpdate ---
func (s *AccountRepoSuite) TestBulkUpdate() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk2", Priority: 1})
newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
Priority: &newPriority,
})
s.Require().NoError(err)
s.Require().GreaterOrEqual(affected, int64(1), "expected at least one affected row")
got1, _ := s.repo.GetByID(s.ctx, a1.ID)
got2, _ := s.repo.GetByID(s.ctx, a2.ID)
s.Require().Equal(99, got1.Priority)
s.Require().Equal(99, got2.Priority)
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bulk-cred",
Credentials: datatypes.JSONMap{"existing": "value"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Credentials: datatypes.JSONMap{"new_key": "new_value"},
})
s.Require().NoError(err)
got, _ := s.repo.GetByID(s.ctx, a1.ID)
s.Require().Equal("value", got.Credentials["existing"])
s.Require().Equal("new_value", got.Credentials["new_key"])
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bulk-extra",
Extra: datatypes.JSONMap{"existing": "val"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Extra: datatypes.JSONMap{"new_key": "new_val"},
})
s.Require().NoError(err)
got, _ := s.repo.GetByID(s.ctx, a1.ID)
s.Require().Equal("val", got.Extra["existing"])
s.Require().Equal("new_val", got.Extra["new_key"])
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, service.AccountBulkUpdate{})
s.Require().NoError(err)
s.Require().Zero(affected)
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
s.Require().NoError(err)
s.Require().Zero(affected)
}
func idsOfAccounts(accounts []service.Account) []int64 {
out := make([]int64, 0, len(accounts))
for i := range accounts {
out = append(out, accounts[i].ID)
}
return out
}

View File

@@ -0,0 +1,60 @@
package repository
import (
"context"
"errors"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour
)
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
func apiKeyRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
}
type apiKeyCache struct {
rdb *redis.Client
}
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
return &apiKeyCache{rdb: rdb}
}
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
key := apiKeyRateLimitKey(userID)
count, err := c.rdb.Get(ctx, key).Int()
if errors.Is(err, redis.Nil) {
return 0, nil
}
return count, err
}
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
key := apiKeyRateLimitKey(userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
_, err := pipe.Exec(ctx)
return err
}
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
key := apiKeyRateLimitKey(userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return c.rdb.Incr(ctx, apiKey).Err()
}
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return c.rdb.Expire(ctx, apiKey, ttl).Err()
}

View File

@@ -0,0 +1,127 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ApiKeyCacheSuite struct {
IntegrationRedisSuite
}
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{
{
name: "missing_key_returns_zero_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.NoError(s.T(), err, "expected nil error for missing key")
require.Equal(s.T(), 0, count, "expected zero count for missing key")
},
},
{
name: "increment_increases_count_and_sets_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount")
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2")
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.NoError(s.T(), err, "GetCreateAttemptCount")
require.Equal(s.T(), 2, count, "count mismatch")
ttl, err := rdb.TTL(ctx, key).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration)
},
},
{
name: "delete_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.NoError(s.T(), err, "expected nil error after delete")
require.Equal(s.T(), 0, count, "expected zero count after delete")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
rdb := testRedis(s.T())
cache := &apiKeyCache{rdb: rdb}
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func (s *ApiKeyCacheSuite) TestDailyUsage() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{
{
name: "increment_increases_count",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
dailyKey := "daily:sk-test"
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage")
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2")
n, err := rdb.Get(ctx, dailyKey).Int()
require.NoError(s.T(), err, "Get dailyKey")
require.Equal(s.T(), 2, n, "expected daily usage=2")
},
},
{
name: "set_expiry_sets_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
dailyKey := "daily:sk-test-expiry"
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey))
require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry")
ttl, err := rdb.TTL(ctx, dailyKey).Result()
require.NoError(s.T(), err, "TTL dailyKey")
require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := &apiKeyCache{rdb: rdb}
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func TestApiKeyCacheSuite(t *testing.T) {
suite.Run(t, new(ApiKeyCacheSuite))
}

View File

@@ -0,0 +1,46 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestApiKeyRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "apikey:ratelimit:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "apikey:ratelimit:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "apikey:ratelimit:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "apikey:ratelimit:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := apiKeyRateLimitKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -2,54 +2,68 @@ package repository
import (
"context"
"sub2api/internal/model"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type ApiKeyRepository struct {
type apiKeyRepository struct {
db *gorm.DB
}
func NewApiKeyRepository(db *gorm.DB) *ApiKeyRepository {
return &ApiKeyRepository{db: db}
func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &apiKeyRepository{db: db}
}
func (r *ApiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Create(key).Error
}
func (r *ApiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
var key model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
if err != nil {
return nil, err
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return &key, nil
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
}
func (r *ApiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
var apiKey model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
}
return &apiKey, nil
return apiKeyModelToService(&m), nil
}
func (r *ApiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
}
return apiKeyModelToService(&m), nil
}
func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return err
}
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
var keys []model.ApiKey
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel
var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID)
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -59,36 +73,31 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return keys, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outKeys, paginationResultFromTotal(total, params), nil
}
func (r *ApiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
}
func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
return count > 0, err
}
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
var keys []model.ApiKey
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel
var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID)
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -98,24 +107,19 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return keys, &PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outKeys, paginationResultFromTotal(total, params), nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
var keys []model.ApiKey
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
var keys []apiKeyModel
db := r.db.WithContext(ctx).Model(&model.ApiKey{})
db := r.db.WithContext(ctx).Model(&apiKeyModel{})
if userID > 0 {
db = db.Where("user_id = ?", userID)
@@ -130,20 +134,84 @@ func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return nil, err
}
return keys, nil
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return outKeys, nil
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("group_id = ?", groupID).
Update("group_id", nil)
return result.RowsAffected, result.Error
}
// CountByGroupID 获取分组的 API Key 数量
func (r *ApiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err
}
type apiKeyModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
Key string `gorm:"uniqueIndex;size:128;not null"`
Name string `gorm:"size:100;not null"`
GroupID *int64 `gorm:"index"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (apiKeyModel) TableName() string { return "api_keys" }
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
if m == nil {
return nil
}
return &service.ApiKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
GroupID: m.GroupID,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
}
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel {
if k == nil {
return nil
}
return &apiKeyModel{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
}
}
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) {
if key == nil || m == nil {
return
}
key.ID = m.ID
key.CreatedAt = m.CreatedAt
key.UpdatedAt = m.UpdatedAt
}

View File

@@ -0,0 +1,355 @@
//go:build integration
package repository
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type ApiKeyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *apiKeyRepository
}
func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository)
}
func TestApiKeyRepoSuite(t *testing.T) {
suite.Run(t, new(ApiKeyRepoSuite))
}
// --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-create-test",
Name: "Test Key",
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, key)
s.Require().NoError(err, "Create")
s.Require().NotZero(key.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("sk-create-test", got.Key)
}
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *ApiKeyRepoSuite) TestGetByKey() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"})
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
GroupID: &group.ID,
Status: service.StatusActive,
})
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
s.Require().Equal(key.ID, got.ID)
s.Require().NotNil(got.User, "expected User preload")
s.Require().Equal(user.ID, got.User.ID)
s.Require().NotNil(got.Group, "expected Group preload")
s.Require().Equal(group.ID, got.Group.ID)
}
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
s.Require().Error(err, "expected error for non-existent key")
}
// --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
Status: service.StatusActive,
}))
key.Name = "Renamed"
key.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("sk-update", got.Key, "Update should not change key")
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got.Name)
s.Require().Equal(service.StatusDisabled, got.Status)
}
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
GroupID: &group.ID,
}))
key.GroupID = nil
err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err)
s.Require().Nil(got.GroupID, "expected GroupID to be cleared")
}
// --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
})
err := s.repo.Delete(s.ctx, key.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, key.ID)
s.Require().Error(err, "expected error after delete")
}
// --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"})
for i := 0; i < 5; i++ {
mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-page-" + string(rune('a'+i)),
Name: "Key",
})
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
s.Require().NoError(err)
s.Require().Len(keys, 2)
s.Require().Equal(int64(5), page.Total)
s.Require().Equal(3, page.Pages)
}
func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
s.Require().Equal(int64(2), count)
}
// --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
// User preloaded
s.Require().NotNil(keys[0].User)
}
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(1), count)
}
// --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"})
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists)
notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- SearchApiKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
s.Require().Len(found, 1)
s.Require().Contains(found[0].Name, "Production")
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
s.Require().Len(found, 2)
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
s.Require().Len(found, 1)
}
// --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"})
k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
s.Require().Equal(int64(2), affected)
got1, _ := s.repo.GetByID(s.ctx, k1.ID)
got2, _ := s.repo.GetByID(s.ctx, k2.ID)
s.Require().Nil(got1.GroupID)
s.Require().Nil(got2.GroupID)
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().Zero(count)
}
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-test-1",
Name: "My Key",
GroupID: &group.ID,
Status: service.StatusActive,
}))
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
s.Require().Equal(key.ID, got.ID)
s.Require().NotNil(got.User)
s.Require().Equal(user.ID, got.User.ID)
s.Require().NotNil(got.Group)
s.Require().Equal(group.ID, got.Group.ID)
key.Name = "Renamed"
key.Status = service.StatusDisabled
key.GroupID = nil
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
got2, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got2.Name)
s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(keys, 1)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1")
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists, "expected key to exist")
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchApiKeys")
s.Require().Len(found, 1)
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-test-2",
Name: "Group Key",
GroupID: &group.ID,
})
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear")
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
got3, err := s.repo.GetByID(s.ctx, k2.ID)
s.Require().NoError(err, "GetByID")
s.Require().Nil(got3.GroupID, "expected GroupID cleared")
countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID after clear")
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}

View File

@@ -0,0 +1,20 @@
package repository
import "gorm.io/gorm"
// AutoMigrate runs schema migrations for all repository persistence models.
// Persistence models are defined within individual `*_repo.go` files.
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&userModel{},
&apiKeyModel{},
&groupModel{},
&accountModel{},
&accountGroupModel{},
&proxyModel{},
&redeemCodeModel{},
&usageLogModel{},
&settingModel{},
&userSubscriptionModel{},
)
}

View File

@@ -0,0 +1,183 @@
package repository
import (
"context"
"errors"
"fmt"
"log"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
)
// billingBalanceKey generates the Redis key for user balance cache.
func billingBalanceKey(userID int64) string {
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
}
// billingSubKey generates the Redis key for subscription cache.
func billingSubKey(userID, groupID int64) string {
return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
}
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
subFieldDailyUsage = "daily_usage"
subFieldWeeklyUsage = "weekly_usage"
subFieldMonthlyUsage = "monthly_usage"
subFieldVersion = "version"
)
var (
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
updateSubUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
type billingCache struct {
rdb *redis.Client
}
func NewBillingCache(rdb *redis.Client) service.BillingCache {
return &billingCache{rdb: rdb}
}
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
key := billingBalanceKey(userID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return 0, err
}
return strconv.ParseFloat(val, 64)
}
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := billingBalanceKey(userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
}
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := billingBalanceKey(userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
}
return nil
}
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
key := billingBalanceKey(userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
key := billingSubKey(userID, groupID)
result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, redis.Nil
}
return c.parseSubscriptionCache(result)
}
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) {
result := &service.SubscriptionCacheData{}
result.Status = data[subFieldStatus]
if result.Status == "" {
return nil, errors.New("invalid cache: missing status")
}
if expiresStr, ok := data[subFieldExpiresAt]; ok {
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
if err == nil {
result.ExpiresAt = time.Unix(expiresAt, 0)
}
}
if dailyStr, ok := data[subFieldDailyUsage]; ok {
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
}
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
}
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
}
if versionStr, ok := data[subFieldVersion]; ok {
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
}
return result, nil
}
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error {
if data == nil {
return nil
}
key := billingSubKey(userID, groupID)
fields := map[string]any{
subFieldStatus: data.Status,
subFieldExpiresAt: data.ExpiresAt.Unix(),
subFieldDailyUsage: data.DailyUsage,
subFieldWeeklyUsage: data.WeeklyUsage,
subFieldMonthlyUsage: data.MonthlyUsage,
subFieldVersion: data.Version,
}
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := billingSubKey(userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
}
return nil
}
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
key := billingSubKey(userID, groupID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,283 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type BillingCacheSuite struct {
IntegrationRedisSuite
}
func (s *BillingCacheSuite) TestUserBalance() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
_, err := cache.GetUserBalance(ctx, 1)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
},
},
{
name: "deduct_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(1)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
_, err := rdb.Get(ctx, balanceKey).Result()
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
},
},
{
name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(2)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
got, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance")
require.Equal(s.T(), 10.5, got, "balance mismatch")
ttl, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
},
},
{
name: "deduct_reduces_balance",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(3)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
got, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance after deduct")
require.Equal(s.T(), 8.25, got, "deduct mismatch")
},
},
{
name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(100)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
exists, err := rdb.Exists(ctx, balanceKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
exists, err = rdb.Exists(ctx, balanceKey).Result()
require.NoError(s.T(), err, "Exists after invalidate")
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
_, err = cache.GetUserBalance(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
},
},
{
name: "deduct_refreshes_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(103)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL before deduct")
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
balance, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance")
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL after deduct")
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func (s *BillingCacheSuite) TestSubscriptionCache() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(10)
groupID := int64(20)
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
},
},
{
name: "update_usage_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(11)
groupID := int64(21)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
exists, err := rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
},
},
{
name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(12)
groupID := int64(22)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 7,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.NoError(s.T(), err, "GetSubscriptionCache")
require.Equal(s.T(), "active", gotSub.Status)
require.Equal(s.T(), int64(7), gotSub.Version)
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
ttl, err := rdb.TTL(ctx, subKey).Result()
require.NoError(s.T(), err, "TTL subKey")
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
},
},
{
name: "update_usage_increments_all_fields",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(13)
groupID := int64(23)
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.NoError(s.T(), err, "GetSubscriptionCache after update")
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
},
},
{
name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(101)
groupID := int64(10)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
exists, err := rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
exists, err = rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists after invalidate")
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
},
},
{
name: "missing_status_returns_parsing_error",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(102)
groupID := int64(11)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]any{
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
"daily_usage": 1.0,
"weekly_usage": 2.0,
"monthly_usage": 3.0,
"version": 1,
}
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.Error(s.T(), err, "expected error for missing status field")
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
require.Equal(s.T(), "invalid cache: missing status", err.Error())
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func TestBillingCacheSuite(t *testing.T) {
suite.Run(t, new(BillingCacheSuite))
}

View File

@@ -0,0 +1,87 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestBillingBalanceKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "billing:balance:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "billing:balance:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "billing:balance:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "billing:balance:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingBalanceKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}
func TestBillingSubKey(t *testing.T) {
tests := []struct {
name string
userID int64
groupID int64
expected string
}{
{
name: "normal_ids",
userID: 123,
groupID: 456,
expected: "billing:sub:123:456",
},
{
name: "zero_ids",
userID: 0,
groupID: 0,
expected: "billing:sub:0:0",
},
{
name: "negative_ids",
userID: -1,
groupID: -2,
expected: "billing:sub:-1:-2",
},
{
name: "max_int64_ids",
userID: math.MaxInt64,
groupID: math.MaxInt64,
expected: "billing:sub:9223372036854775807:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingSubKey(tc.userID, tc.groupID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -0,0 +1,246 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
)
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
return &claudeOAuthService{
baseURL: "https://claude.ai",
tokenURL: oauth.TokenURL,
clientFactory: createReqClient,
}
}
type claudeOAuthService struct {
baseURL string
tokenURL string
clientFactory func(proxyURL string) *req.Client
}
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := s.clientFactory(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
}
targetURL := s.baseURL + "/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetSuccessResult(&orgs).
Get(targetURL)
if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
}
if len(orgs) == 0 {
return "", fmt.Errorf("no organizations found")
}
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
return orgs[0].UUID, nil
}
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := s.clientFactory(proxyURL)
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
reqBody := map[string]any{
"response_type": "code",
"client_id": oauth.ClientID,
"organization_uuid": orgUUID,
"redirect_uri": oauth.RedirectURI,
"scope": scope,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct {
RedirectURI string `json:"redirect_uri"`
}
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetHeader("Accept", "application/json").
SetHeader("Accept-Language", "en-US,en;q=0.9").
SetHeader("Cache-Control", "no-cache").
SetHeader("Origin", "https://claude.ai").
SetHeader("Referer", "https://claude.ai/new").
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&result).
Post(authURL)
if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
}
if result.RedirectURI == "" {
return "", fmt.Errorf("no redirect_uri in response")
}
parsedURL, err := url.Parse(result.RedirectURI)
if err != nil {
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
}
queryParams := parsedURL.Query()
authCode := queryParams.Get("code")
responseState := queryParams.Get("state")
if authCode == "" {
return "", fmt.Errorf("no authorization code in redirect_uri")
}
fullCode := authCode
if responseState != "" {
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
return fullCode, nil
}
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
// Parse code which may contain state in format "authCode#state"
authCode := code
codeState := ""
if idx := strings.Index(code, "#"); idx != -1 {
authCode = code[:idx]
codeState = code[idx+1:]
}
reqBody := map[string]any{
"code": authCode,
"grant_type": "authorization_code",
"client_id": oauth.ClientID,
"redirect_uri": oauth.RedirectURI,
"code_verifier": codeVerifier,
}
if codeState != "" {
reqBody["state"] = codeState
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
return nil, fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
return &tokenResp, nil
}
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", oauth.ClientID)
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func createReqClient(proxyURL string) *req.Client {
client := req.C().
ImpersonateChrome().
SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}
func prefix(s string, n int) string {
if n <= 0 {
return ""
}
if len(s) <= n {
return s
}
return s[:n]
}

View File

@@ -0,0 +1,343 @@
package repository
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ClaudeOAuthServiceSuite struct {
suite.Suite
srv *httptest.Server
client *claudeOAuthService
}
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// requestCapture holds captured request data for assertions in the main goroutine.
type requestCapture struct {
path string
method string
cookies []*http.Cookie
body []byte
formValues url.Values
bodyJSON map[string]any
contentType string
}
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
errContain string
wantUUID string
validate func(captured requestCapture)
}{
{
name: "success",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`[{"uuid":"org-1"}]`))
},
wantUUID: "org-1",
validate: func(captured requestCapture) {
require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path")
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
require.Equal(s.T(), "sessionKey", captured.cookies[0].Name)
require.Equal(s.T(), "sess", captured.cookies[0].Value)
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("unauthorized"))
},
wantErr: true,
errContain: "401",
},
{
name: "invalid_json_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte("not-json"))
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.cookies = r.Cookies()
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = s.srv.URL
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
if tt.wantErr {
require.Error(s.T(), err)
if tt.errContain != "" {
require.ErrorContains(s.T(), err, tt.errContain)
}
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantUUID, got)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
wantCode string
validate func(captured requestCapture)
}{
{
name: "parses_redirect_uri",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE",
})
},
wantCode: "AUTH#STATE",
validate: func(captured requestCapture) {
require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path)
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
require.Equal(s.T(), "sess", captured.cookies[0].Value)
require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
require.Equal(s.T(), "st", captured.bodyJSON["state"])
},
},
{
name: "missing_code_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"redirect_uri": oauth.RedirectURI + "?state=STATE", // no code
})
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.method = r.Method
captured.cookies = r.Cookies()
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = s.srv.URL
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantCode, code)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
tests := []struct {
name string
handler http.HandlerFunc
code string
wantErr bool
wantResp *oauth.TokenResponse
validate func(captured requestCapture)
}{
{
name: "sends_state_when_embedded",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "at",
TokenType: "bearer",
ExpiresIn: 3600,
RefreshToken: "rt",
Scope: "s",
})
},
code: "AUTH#STATE2",
wantResp: &oauth.TokenResponse{
AccessToken: "at",
RefreshToken: "rt",
},
validate: func(captured requestCapture) {
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type")
require.Equal(s.T(), "AUTH", captured.bodyJSON["code"])
require.Equal(s.T(), "STATE2", captured.bodyJSON["state"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("bad request"))
},
code: "AUTH",
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = s.srv.URL
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
wantResp *oauth.TokenResponse
validate func(captured requestCapture)
}{
{
name: "sends_form",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{AccessToken: "at2", TokenType: "bearer", ExpiresIn: 3600})
},
wantResp: &oauth.TokenResponse{AccessToken: "at2"},
validate: func(captured requestCapture) {
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.Equal(s.T(), "refresh_token", captured.formValues.Get("grant_type"))
require.Equal(s.T(), "rt", captured.formValues.Get("refresh_token"))
require.Equal(s.T(), oauth.ClientID, captured.formValues.Get("client_id"))
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("unauthorized"))
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.body, _ = io.ReadAll(r.Body)
captured.formValues, _ = url.ParseQuery(string(captured.body))
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = s.srv.URL
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func TestClaudeOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeOAuthServiceSuite))
}

View File

@@ -0,0 +1,67 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
type claudeUsageService struct {
usageURL string
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
}
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, fmt.Errorf("failed to get default transport")
}
transport = transport.Clone()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
client := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
}
var usageResp service.ClaudeUsageResponse
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
return nil, fmt.Errorf("decode response failed: %w", err)
}
return &usageResp, nil
}

View File

@@ -0,0 +1,105 @@
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ClaudeUsageServiceSuite struct {
suite.Suite
srv *httptest.Server
fetcher *claudeUsageService
}
func (s *ClaudeUsageServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// usageRequestCapture holds captured request data for assertions in the main goroutine.
type usageRequestCapture struct {
authorization string
anthropicBeta string
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
var captured usageRequestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.authorization = r.Header.Get("Authorization")
captured.anthropicBeta = r.Header.Get("anthropic-beta")
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{
"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
}`)
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage")
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch")
// Assertions on captured request data
require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch")
require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "nope")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 401")
require.ErrorContains(s.T(), err, "nope")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "decode response failed")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never respond - simulate slow server
<-r.Context().Done()
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := s.fetcher.FetchUsage(ctx, "at", "")
require.Error(s.T(), err, "expected error for cancelled context")
}
func TestClaudeUsageServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeUsageServiceSuite))
}

View File

@@ -0,0 +1,203 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
// Key prefixes for independent slot keys
// Format: concurrency:account:{accountID}:{requestID}
accountSlotKeyPrefix = "concurrency:account:"
// Format: concurrency:user:{userID}:{requestID}
userSlotKeyPrefix = "concurrency:user:"
// Wait queue keeps counter format: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// Slot TTL - each slot expires independently
slotTTL = 5 * time.Minute
)
var (
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
// ARGV[1] = maxConcurrency
// ARGV[2] = TTL in seconds
acquireScript = redis.NewScript(`
local pattern = KEYS[1]
local slotKey = KEYS[2]
local maxConcurrency = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
-- Count existing slots using SCAN
local cursor = "0"
local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
-- Check if we can acquire a slot
if count < maxConcurrency then
redis.call('SET', slotKey, '1', 'EX', ttl)
return 1
end
return 0
`)
// getCountScript counts slots using SCAN
// KEYS[1] = pattern for SCAN
getCountScript = redis.NewScript(`
local pattern = KEYS[1]
local cursor = "0"
local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
return count
`)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
incrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
)
type concurrencyCache struct {
rdb *redis.Client
}
func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache {
return &concurrencyCache{rdb: rdb}
}
// Helper functions for key generation
func accountSlotKey(accountID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
}
func accountSlotPattern(accountID int64) string {
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
}
func userSlotKey(userID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
}
func userSlotPattern(userID int64) string {
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
}
func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := accountSlotPattern(accountID)
slotKey := accountSlotKey(accountID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
slotKey := accountSlotKey(accountID, requestID)
return c.rdb.Del(ctx, slotKey).Err()
}
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
pattern := accountSlotPattern(accountID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
if err != nil {
return 0, err
}
return result, nil
}
// User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := userSlotPattern(userID)
slotKey := userSlotKey(userID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
slotKey := userSlotKey(userID, requestID)
return c.rdb.Del(ctx, slotKey).Err()
}
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
pattern := userSlotPattern(userID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
if err != nil {
return 0, err
}
return result, nil
}
// Wait queue operations
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}

View File

@@ -0,0 +1,231 @@
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ConcurrencyCacheSuite struct {
IntegrationRedisSuite
cache service.ConcurrencyCache
}
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
accountID := int64(10)
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
require.NoError(s.T(), err, "AcquireAccountSlot 1")
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
require.NoError(s.T(), err, "AcquireAccountSlot 2")
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
require.NoError(s.T(), err, "AcquireAccountSlot 3")
require.False(s.T(), ok, "expected third acquire to fail")
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err, "GetAccountConcurrency")
require.Equal(s.T(), 2, cur, "concurrency mismatch")
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err, "GetAccountConcurrency after release")
require.Equal(s.T(), 1, cur, "expected 1 after release")
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
accountID := int64(11)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID)
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
require.NoError(s.T(), err, "AcquireAccountSlot")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
accountID := int64(12)
reqID := "dup-req"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Acquiring with same reqID should be idempotent
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
accountID := int64(13)
reqID := "release-test"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
// Releasing again should not error
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
// Releasing non-existent should not error
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
accountID := int64(14)
reqID := "max-zero-test"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
require.NoError(s.T(), err)
require.False(s.T(), ok, "expected acquire to fail with max=0")
}
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
userID := int64(42)
reqID1, reqID2 := "req1", "req2"
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
require.NoError(s.T(), err, "AcquireUserSlot")
require.True(s.T(), ok)
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
require.NoError(s.T(), err, "AcquireUserSlot 2")
require.False(s.T(), ok, "expected second acquire to fail at max=1")
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
require.NoError(s.T(), err, "GetUserConcurrency")
require.Equal(s.T(), 1, cur, "expected concurrency=1")
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
// Releasing a non-existent slot should not error
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
require.NoError(s.T(), err, "GetUserConcurrency after release")
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
}
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
userID := int64(200)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID)
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
require.NoError(s.T(), err, "AcquireUserSlot")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected wait count 1")
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
// Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
// Verify no key was created or it's not negative
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
// Set count to 1, then decrement twice
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
require.NoError(s.T(), err, "IncrementWaitCount")
require.True(s.T(), ok)
// Decrement once (1 -> 0)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
// Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
// Verify count is 0, not negative
val, err = s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey after double decrement")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
// When no slots exist, GetUserConcurrency should return 0
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}

Some files were not shown because too many files have changed in this diff Show More